@@ -66,16 +66,20 @@ def df2tl(self, df):
66
66
df = df [text_cols ]
67
67
return TextList .from_df (df , cols = text_cols )
68
68
69
- def get_features (self , evidences ):
69
+ def get_features (self , evidences , use_crf = True ):
70
+ if use_crf :
71
+ learner = self .learner
72
+ else :
73
+ learner = self ._full_learner
70
74
if len (evidences ):
71
75
tl = self .df2tl (evidences )
72
- self . learner .data .add_test (tl )
76
+ learner .data .add_test (tl )
73
77
74
- preds , _ = self . learner .get_preds (DatasetType .Test , ordered = True )
78
+ preds , _ = learner .get_preds (DatasetType .Test , ordered = True )
75
79
return preds .cpu ().numpy ()
76
- return np .zeros ((0 , n_ulmfit_features ))
80
+ return np .zeros ((0 , n_ulmfit_features if use_crf else n_classes ))
77
81
78
- def to_tables (self , df , transpose = False ):
82
+ def to_tables (self , df , transpose = False , n_ulmfit_features = n_ulmfit_features ):
79
83
X_tables = []
80
84
Y_tables = []
81
85
ids = []
@@ -127,12 +131,12 @@ def merge_with_preds(self, df, preds):
127
131
return list (zip (ext_id [0 ] + "/" + ext_id [1 ], ext_id [2 ].astype (int ), ext_id [3 ].astype (int ),
128
132
preds , df .text , df .cell_content , df .cell_layout , df .cell_styles , df .cell_reference , df .label ))
129
133
130
- def merge_all_with_preds (self , df , df_num , preds ):
134
+ def merge_all_with_preds (self , df , df_num , preds , use_crf = True ):
131
135
columns = ["table_id" , "row" , "col" , "features" , "text" , "cell_content" , "cell_layout" ,
132
136
"cell_styles" , "cell_reference" , "label" ]
133
137
134
138
alpha = self .merge_with_preds (df , preds )
135
- nums = self .merge_with_preds (df_num , np .zeros ((len (df_num ), n_ulmfit_features )))
139
+ nums = self .merge_with_preds (df_num , np .zeros ((len (df_num ), n_ulmfit_features if use_crf else n_classes )))
136
140
137
141
df1 = pd .DataFrame (alpha , columns = columns )
138
142
df2 = pd .DataFrame (nums , columns = columns )
@@ -156,13 +160,16 @@ def format_predictions(self, tables_preds, test_ids):
156
160
labels [r , c ]])
157
161
return pd .DataFrame (flat , columns = ["paper" , "table" , "row" , "col" , "predicted_tags" ])
158
162
159
- def predict_tags (self , raw_evidences ):
163
+ def predict_tags (self , raw_evidences , use_crf = True ):
160
164
evidences , evidences_num = self .keep_alphacells (self .preprocess_df (raw_evidences ))
161
165
pipeline_logger (f"{ TableStructurePredictor .step } ::evidences_split" , evidences = evidences , evidences_num = evidences_num )
162
- features = self .get_features (evidences )
163
- df = self .merge_all_with_preds (evidences , evidences_num , features )
164
- tables , contents , ids = self .to_tables (df )
165
- preds = self .crf .predict (tables )
166
+ features = self .get_features (evidences , use_crf )
167
+ df = self .merge_all_with_preds (evidences , evidences_num , features , use_crf )
168
+ tables , contents , ids = self .to_tables (df , n_ulmfit_features = n_ulmfit_features if use_crf else n_classes )
169
+ if use_crf :
170
+ preds = self .crf .predict (tables )
171
+ else :
172
+ preds = [table [..., :n_classes ].argmax (axis = - 1 ) for table in tables ]
166
173
return self .format_predictions (preds , ids )
167
174
168
175
# todo: consider adding sota/ablation information
@@ -179,10 +186,10 @@ def label_table(self, paper, table, annotations, in_place):
179
186
return table
180
187
181
188
# todo: take EvidenceExtractor in constructor
182
- def label_tables (self , paper , tables , raw_evidences , in_place = False ):
189
+ def label_tables (self , paper , tables , raw_evidences , in_place = False , use_crf = True ):
183
190
pipeline_logger (f"{ TableStructurePredictor .step } ::label_tables" , paper = paper , tables = tables , raw_evidences = raw_evidences )
184
191
if len (raw_evidences ):
185
- tags = self .predict_tags (raw_evidences )
192
+ tags = self .predict_tags (raw_evidences , use_crf )
186
193
annotations = dict (list (tags .groupby (by = ["paper" , "table" ])))
187
194
else :
188
195
annotations = {} # just deep-copy all tables
0 commit comments