@@ -42,15 +42,17 @@ def __init__(self, path, file, crf_path=None, crf_model="crf.pkl",
42
42
self ._full_learner = deepcopy (self .learner )
43
43
self .learner .model = cut_ulmfit_head (self .learner .model )
44
44
self .learner .loss_func = None
45
+
46
+ #todo: make CRF optional
45
47
crf_path = Path (path ) if crf_path is None else Path (crf_path )
46
48
self .crf = load_crf (crf_path / crf_model )
47
49
48
50
# todo: clean Experiment from older approaches
49
51
self ._e = ULMFiTExperiment (remove_num = False , drop_duplicates = False ,
50
52
this_paper = True , merge_fragments = True , merge_type = 'concat' ,
51
53
evidence_source = 'text_highlited' , split_btags = True , fixed_tokenizer = True ,
52
- fixed_this_paper = True , mask = False , evidence_limit = None , context_tokens = None ,
53
- lowercase = True )
54
+ fixed_this_paper = True , mask = True , evidence_limit = None , context_tokens = None ,
55
+ lowercase = True , drop_mult = 0.15 , fp16 = True , train_on_easy = False )
54
56
55
57
def preprocess_df (self , raw_df ):
56
58
return self ._e .transform_df (raw_df )
@@ -169,7 +171,11 @@ def predict_tags(self, raw_evidences, use_crf=True):
169
171
if use_crf :
170
172
preds = self .crf .predict (tables )
171
173
else :
172
- preds = [table [..., :n_classes ].argmax (axis = - 1 ) for table in tables ]
174
+ preds = []
175
+ for table in tables :
176
+ p = table [..., :n_classes ].argmax (axis = - 1 )
177
+ p [table [..., :n_classes ].max (axis = - 1 ) == 0.0 ] = n_classes
178
+ preds .append (p )
173
179
return self .format_predictions (preds , ids )
174
180
175
181
# todo: consider adding sota/ablation information
0 commit comments