Commit 6faf6e3

Add ULMFiTExperiment

* add ULMFiTExperiment
* add caching to evidence extraction
* move older Experiment params to NBSVMExperiment

1 parent 184224c commit 6faf6e3

File tree

5 files changed: +230 -47 lines changed

sota_extractor2/config.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 import logging
-from pathlib import Path
+from pathlib import Path

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt='%m/%d/%Y %H:%M:%S',

sota_extractor2/data/structure.py

Lines changed: 35 additions & 11 deletions

@@ -46,35 +46,58 @@ def empty_fragment(paper_id):
     return fragment


-def fetch_evidence(cell_content, cell_reference, paper_id, table_name, row, col, paper_limit=10, corpus_limit=10):
+def normalize_query(query):
+    if isinstance(query, list):
+        return tuple(normalize_query(x) for x in query)
+    if isinstance(query, dict):
+        return tuple([(normalize_query(k), normalize_query(v)) for k,v in query.items()])
+    return query
+
+_evidence_cache = {}
+_cache_miss = 0
+_cache_hit = 0
+def get_cached_or_execute(query):
+    global _evidence_cache, _cache_hit, _cache_miss
+    n = normalize_query(query.to_dict())
+    if n not in _evidence_cache:
+        _evidence_cache[n] = list(query)
+        _cache_miss += 1
+    else:
+        _cache_hit += 1
+    return _evidence_cache[n]
+
+
+def fetch_evidence(cell_content, cell_reference, paper_id, table_name, row, col, paper_limit=10, corpus_limit=10,
+                   cache=False):
     if not filter_cells(cell_content):
         return [empty_fragment(paper_id)]
     cell_content = clear_cell(cell_content)
     if cell_content == "" and cell_reference == "":
         return [empty_fragment(paper_id)]

+    cached_query = get_cached_or_execute if cache else lambda x: x
     evidence_query = Fragment.search().highlight(
         'text', pre_tags="<b>", post_tags="</b>", fragment_size=400)
     cell_content = cell_content.replace("\xa0", " ")
     query = {
         "query": cell_content,
         "slop": 2
     }
-    paper_fragments = list(evidence_query
-                           .filter('term', paper_id=paper_id)
-                           .query('match_phrase', text=query)[:paper_limit])
+    paper_fragments = list(cached_query(evidence_query
+                           .filter('term', paper_id=paper_id)
+                           .query('match_phrase', text=query)[:paper_limit]))
     if cell_reference != "":
-        reference_fragments = list(evidence_query
+        reference_fragments = list(cached_query(evidence_query
                                    .filter('term', paper_id=paper_id)
                                    .query('match_phrase', text={
                                        "query": cell_reference,
                                        "slop": 1
-                                   })[:paper_limit])
+                                   })[:paper_limit]))
     else:
         reference_fragments = []
-    other_fagements = list(evidence_query
+    other_fagements = list(cached_query(evidence_query
                            .exclude('term', paper_id=paper_id)
-                           .query('match_phrase', text=query)[:corpus_limit])
+                           .query('match_phrase', text=query)[:corpus_limit]))

     ext_id = f"{paper_id}/{table_name}/{row}.{col}"
     ####print(f"{ext_id} |{cell_content}|: {len(paper_fragments)} paper fragments, {len(reference_fragments)} reference fragments, {len(other_fagements)} other fragments")
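
A note on the cache key: normalize_query recursively converts the dicts and lists of a query's to_dict() form into nested tuples, making the result hashable so it can index _evidence_cache. A minimal sketch of that behavior, with a hypothetical query dict:

    # hypothetical query dict, shaped like the to_dict() output of an
    # elasticsearch-dsl search; not taken from the repo's test data
    q = {"match_phrase": {"text": {"query": "BLEU", "slop": 2}}}
    key = normalize_query(q)
    # nested dicts and lists become nested tuples, so the key is hashable:
    assert key == (("match_phrase", (("text", (("query", "BLEU"), ("slop", 2))),)),)
    _evidence_cache[key] = ["<b>BLEU</b> of 27.3 ..."]  # cached list of hits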

@@ -137,22 +160,23 @@ def filter_cells(cell_content):
 interesting_types = ["model-paper", "model-best", "model-competing", "dataset", "dataset-sub", "dataset-task"]


-def evidence_for_table(paper_id, table, paper_limit, corpus_limit):
+def evidence_for_table(paper_id, table, paper_limit, corpus_limit, cache=False):
     records = [
         record
         for cell in consume_cells(table)
         for evidence in fetch_evidence(cell.vals[0], cell.vals[2], paper_id=paper_id, table_name=table.name,
-                                       row=cell.row, col=cell.col, paper_limit=paper_limit, corpus_limit=corpus_limit)
+                                       row=cell.row, col=cell.col, paper_limit=paper_limit, corpus_limit=corpus_limit,
+                                       cache=cache)
         for record in create_evidence_records(evidence, cell, paper_id=paper_id, table=table)
     ]
     df = pd.DataFrame.from_records(records, columns=evidence_columns)
     return df


-def prepare_data(tables, csv_path):
+def prepare_data(tables, csv_path, cache=False):
     data = [evidence_for_table(table.paper_id, table,
                                paper_limit=100,
-                               corpus_limit=20) for table in progress_bar(tables)]
+                               corpus_limit=20, cache=cache) for table in progress_bar(tables)]
     if len(data):
         df = pd.concat(data)
     else:
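
With the new flag threaded through prepare_data → evidence_for_table → fetch_evidence, a hedged usage sketch (the variable names and csv path are illustrative, not from the repo):

    # cache=True routes each query through get_cached_or_execute, so repeated
    # identical cell contents hit Elasticsearch only once per unique query
    evidence_df = prepare_data(tables, csv_path="evidence.csv", cache=True)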

sota_extractor2/models/structure/experiment.py

Lines changed: 57 additions & 26 deletions

@@ -34,9 +34,21 @@ class Labels(Enum):
 # put here to avoid recompiling, used only in _limit_context
 elastic_tag_split_re = re.compile("(<b>.*?</b>)")

+# e = Experiment(remove_num=False, drop_duplicates=False, vectorizer='count',
+#                this_paper=True, merge_fragments=True, merge_type='concat',
+#                evidence_source='text_highlited', split_btags=True, fixed_tokenizer=True,
+#                fixed_this_paper=True, mask=False, evidence_limit=None, context_tokens=None,
+#                analyzer='word', lowercase=True, class_weight='balanced', multinomial_type='multinomial',
+#                solver='lbfgs', C=0.1, dual=False, penalty='l2', ngram_range=[1, 3],
+#                min_df=10, max_df=0.9, max_iter=1000, results={}, has_model=False)
+
+# ULMFiT related parameters
+# remove_num, drop_duplicates, this_paper, merge_fragments, merge_type, evidence_source, split_btags
+# fixed_tokenizer?, fixed_this_paper (remove), mask, evidence_limit, context_tokens, lowercase
+# class_weight? (consider adding support),
+
 @dataclass
 class Experiment:
-    vectorizer: str = "tfidf"
     this_paper: bool = False
     merge_fragments: bool = False
     merge_type: str = "concat"  # "concat", "vote_maj", "vote_avg", "vote_max"

@@ -47,23 +59,11 @@ class Experiment:
     mask: bool = False  # if True and evidence_source = "text_highlited", replace <b>...</b> with xxmask
     evidence_limit: int = None  # maximum number of evidences per cell (grouped by (ext_id, this_paper))
     context_tokens: int = None  # max. number of words before <b> and after </b>
-    analyzer: str = "word"  # "char", "word" or "char_wb"
     lowercase: bool = True
     remove_num: bool = True
     drop_duplicates: bool = True
     mark_this_paper: bool = False

-    class_weight: str = None
-    multinomial_type: str = "manual"  # "manual", "ovr", "multinomial"
-    solver: str = "liblinear"  # 'lbfgs' - large, liblinear for small datasets
-    C: float = 4.0
-    dual: bool = True
-    penalty: str = "l2"
-    ngram_range: tuple = (1, 2)
-    min_df: int = 3
-    max_df: float = 0.9
-    max_iter: int = 1000
-
     results: dict = dataclasses.field(default_factory=dict)

     has_model: bool = False  # either there's already pretrained model or it's a saved experiment and there's a saved model as well

@@ -78,29 +78,39 @@ def _get_next_exp_name(self, dir_path):
                 return dir_path / name
         raise Exception("You have too many files in this dir, really!")

-    def _save_model(self, path):
+    @staticmethod
+    def _dump_pickle(obj, path):
         with open(path, 'wb') as f:
-            pickle.dump(self._model, f)
+            pickle.dump(obj, f)

-    def _load_model(self, path):
+    @staticmethod
+    def _load_pickle(path):
         with open(path, 'rb') as f:
-            self._model = pickle.load(f)
-        return self._model
+            return pickle.load(f)
+
+    def _save_model(self, path):
+        self._dump_pickle(self._model, path)
+
+    def _load_model(self, path):
+        self._model = self._load_pickle(path)
+        return self._model

     def load_model(self):
         path = self._path.parent / f"{self._path.stem}.model"
         return self._load_model(path)

+    def save_model(self, path):
+        if hasattr(self, "_model"):
+            self._save_model(path)
+
     def save(self, dir_path):
         dir_path = Path(dir_path)
         dir_path.mkdir(exist_ok=True, parents=True)
         filename = self._get_next_exp_name(dir_path)
         j = dataclasses.asdict(self)
         with open(filename, "wt") as f:
             json.dump(j, f)
-        if hasattr(self, "_model"):
-            fn = filename.stem
-            self._save_model(dir_path / f"{fn}.model")
+        self.save_model(dir_path / f"{filename.stem}.model")
         return filename.name

     def to_df(self):
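
The persistence refactor means save() always delegates to the new save_model(), which is a no-op unless a model is attached. A brief sketch, assuming a hypothetical output directory:

    exp = NBSVMExperiment(this_paper=True, merge_fragments=True)
    # writes the JSON params file; adds a f"{stem}.model" pickle only if _model exists
    name = exp.save("experiments/structure")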

@@ -119,12 +129,13 @@ def new_experiment(self, **kwargs):
     def update_results(self, **kwargs):
         self.results.update(**kwargs)

-    def get_trained_model(self, train_df):
-        nbsvm = NBSVM(experiment=self)
-        nbsvm.fit(train_df["text"], train_df["label"])
-        self._model = nbsvm
+    def train_model(self, train_df, valid_df):
+        raise NotImplementedError("train_model should be implemented in subclass")
+
+    def get_trained_model(self, train_df, valid_df):
+        self._model = self.train_model(train_df, valid_df)
         self.has_model = True
-        return nbsvm
+        return self._model

     def _limit_context(self, text):
         parts = elastic_tag_split_re.split(text)

@@ -301,3 +312,23 @@ def experiments_to_df(cls, exps):
         dfs = [e.to_df() for e in exps]
         df = pd.concat(dfs)
         return df
+
+@dataclass
+class NBSVMExperiment(Experiment):
+    vectorizer: str = "tfidf"
+    analyzer: str = "word"  # "char", "word" or "char_wb"
+    class_weight: str = None
+    multinomial_type: str = "manual"  # "manual", "ovr", "multinomial"
+    solver: str = "liblinear"  # 'lbfgs' - large, liblinear for small datasets
+    C: float = 4.0
+    dual: bool = True
+    penalty: str = "l2"
+    ngram_range: tuple = (1, 2)
+    min_df: int = 3
+    max_df: float = 0.9
+    max_iter: int = 1000
+
+    def train_model(self, train_df, valid_df=None):
+        nbsvm = NBSVM(experiment=self)
+        nbsvm.fit(train_df["text"], train_df["label"])
+        return nbsvm
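
Training is now a two-step contract: subclasses implement train_model(), and the base class's get_trained_model() stores the result and flips has_model. A sketch using the NBSVMExperiment from the hunk above (assumes train_df has the "text" and "label" columns that fit expects):

    exp = NBSVMExperiment(this_paper=True, merge_fragments=True)
    model = exp.get_trained_model(train_df, valid_df=None)  # dispatches to NBSVMExperiment.train_model
    assert exp.has_model and model is exp._model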

sota_extractor2/models/structure/structure_predictor.py

Lines changed: 9 additions & 9 deletions

@@ -3,7 +3,8 @@
 import pandas as pd
 import numpy as np
 import pickle
-from .experiment import Experiment, Labels, label_map
+from .experiment import Labels, label_map
+from .ulmfit_experiment import ULMFiTExperiment
 import re
 from .ulmfit import ULMFiT_SP
 from ...pipeline_logger import pipeline_logger

@@ -45,13 +46,11 @@ def __init__(self, path, file, crf_path=None, crf_model="crf.pkl",
         self.crf = load_crf(crf_path / crf_model)

         # todo: clean Experiment from older approaches
-        self._e = Experiment(remove_num=False, drop_duplicates=False, vectorizer='count',
-                             this_paper=True, merge_fragments=True, merge_type='concat',
-                             evidence_source='text_highlited', split_btags=True, fixed_tokenizer=True,
-                             fixed_this_paper=True, mask=False, evidence_limit=None, context_tokens=None,
-                             analyzer='word', lowercase=True, class_weight='balanced', multinomial_type='multinomial',
-                             solver='lbfgs', C=0.1, dual=False, penalty='l2', ngram_range=[1, 3],
-                             min_df=10, max_df=0.9, max_iter=1000, results={}, has_model=False)
+        self._e = ULMFiTExperiment(remove_num=False, drop_duplicates=False,
+                                   this_paper=True, merge_fragments=True, merge_type='concat',
+                                   evidence_source='text_highlited', split_btags=True, fixed_tokenizer=True,
+                                   fixed_this_paper=True, mask=False, evidence_limit=None, context_tokens=None,
+                                   lowercase=True)

     def preprocess_df(self, raw_df):
         return self._e.transform_df(raw_df)

@@ -140,7 +139,7 @@ def merge_all_with_preds(self, df, df_num, preds):
         df2.label = n_classes
         return df1.append(df2, ignore_index=True)

-
+    # todo: fix numeric cells being labelled as meta / other
     def format_predictions(self, tables_preds, test_ids):
         num2label = {v: k for k, v in label_map.items()}
         num2label[0] = "table-meta"

@@ -172,6 +171,7 @@ def label_table(self, paper, table, annotations, in_place):
         ext_id = (paper.paper_id, table.name)
         if ext_id in annotations:
             for _, entry in annotations[ext_id].iterrows():
+                # todo: add model-ensemble support
                 structure.iloc[entry.row, entry.col] = entry.predicted_tags if entry.predicted_tags != "model-paper" else "model-best"
         if not in_place:
             table = deepcopy(table)
