Skip to content

Commit 3c816ed

Browse files
committed
Make CRF phase optional
1 parent 1e128f1 commit 3c816ed

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

sota_extractor2/models/structure/structure_predictor.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,17 @@ def __init__(self, path, file, crf_path=None, crf_model="crf.pkl",
4242
self._full_learner = deepcopy(self.learner)
4343
self.learner.model = cut_ulmfit_head(self.learner.model)
4444
self.learner.loss_func = None
45+
46+
#todo: make CRF optional
4547
crf_path = Path(path) if crf_path is None else Path(crf_path)
4648
self.crf = load_crf(crf_path / crf_model)
4749

4850
# todo: clean Experiment from older approaches
4951
self._e = ULMFiTExperiment(remove_num=False, drop_duplicates=False,
5052
this_paper=True, merge_fragments=True, merge_type='concat',
5153
evidence_source='text_highlited', split_btags=True, fixed_tokenizer=True,
52-
fixed_this_paper=True, mask=False, evidence_limit=None, context_tokens=None,
53-
lowercase=True)
54+
fixed_this_paper=True, mask=True, evidence_limit=None, context_tokens=None,
55+
lowercase=True, drop_mult=0.15, fp16=True, train_on_easy=False)
5456

5557
def preprocess_df(self, raw_df):
5658
return self._e.transform_df(raw_df)
@@ -169,7 +171,11 @@ def predict_tags(self, raw_evidences, use_crf=True):
169171
if use_crf:
170172
preds = self.crf.predict(tables)
171173
else:
172-
preds = [table[..., :n_classes].argmax(axis=-1) for table in tables]
174+
preds = []
175+
for table in tables:
176+
p = table[..., :n_classes].argmax(axis=-1)
177+
p[table[..., :n_classes].max(axis=-1) == 0.0] = n_classes
178+
preds.append(p)
173179
return self.format_predictions(preds, ids)
174180

175181
# todo: consider adding sota/ablation information

0 commit comments

Comments
 (0)