Skip to content

Commit 0dea0c9

Browse files
committed
Allow disabling CRF in structure prediction
1 parent 86ab99f commit 0dea0c9

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

sota_extractor2/models/structure/structure_predictor.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,20 @@ def df2tl(self, df):
6666
df = df[text_cols]
6767
return TextList.from_df(df, cols=text_cols)
6868

69-
def get_features(self, evidences):
69+
def get_features(self, evidences, use_crf=True):
70+
if use_crf:
71+
learner = self.learner
72+
else:
73+
learner = self._full_learner
7074
if len(evidences):
7175
tl = self.df2tl(evidences)
72-
self.learner.data.add_test(tl)
76+
learner.data.add_test(tl)
7377

74-
preds, _ = self.learner.get_preds(DatasetType.Test, ordered=True)
78+
preds, _ = learner.get_preds(DatasetType.Test, ordered=True)
7579
return preds.cpu().numpy()
76-
return np.zeros((0, n_ulmfit_features))
80+
return np.zeros((0, n_ulmfit_features if use_crf else n_classes))
7781

78-
def to_tables(self, df, transpose=False):
82+
def to_tables(self, df, transpose=False, n_ulmfit_features=n_ulmfit_features):
7983
X_tables = []
8084
Y_tables = []
8185
ids = []
@@ -127,12 +131,12 @@ def merge_with_preds(self, df, preds):
127131
return list(zip(ext_id[0] + "/" + ext_id[1], ext_id[2].astype(int), ext_id[3].astype(int),
128132
preds, df.text, df.cell_content, df.cell_layout, df.cell_styles, df.cell_reference, df.label))
129133

130-
def merge_all_with_preds(self, df, df_num, preds):
134+
def merge_all_with_preds(self, df, df_num, preds, use_crf=True):
131135
columns = ["table_id", "row", "col", "features", "text", "cell_content", "cell_layout",
132136
"cell_styles", "cell_reference", "label"]
133137

134138
alpha = self.merge_with_preds(df, preds)
135-
nums = self.merge_with_preds(df_num, np.zeros((len(df_num), n_ulmfit_features)))
139+
nums = self.merge_with_preds(df_num, np.zeros((len(df_num), n_ulmfit_features if use_crf else n_classes)))
136140

137141
df1 = pd.DataFrame(alpha, columns=columns)
138142
df2 = pd.DataFrame(nums, columns=columns)
@@ -156,13 +160,16 @@ def format_predictions(self, tables_preds, test_ids):
156160
labels[r, c]])
157161
return pd.DataFrame(flat, columns=["paper", "table", "row", "col", "predicted_tags"])
158162

159-
def predict_tags(self, raw_evidences):
163+
def predict_tags(self, raw_evidences, use_crf=True):
160164
evidences, evidences_num = self.keep_alphacells(self.preprocess_df(raw_evidences))
161165
pipeline_logger(f"{TableStructurePredictor.step}::evidences_split", evidences=evidences, evidences_num=evidences_num)
162-
features = self.get_features(evidences)
163-
df = self.merge_all_with_preds(evidences, evidences_num, features)
164-
tables, contents, ids = self.to_tables(df)
165-
preds = self.crf.predict(tables)
166+
features = self.get_features(evidences, use_crf)
167+
df = self.merge_all_with_preds(evidences, evidences_num, features, use_crf)
168+
tables, contents, ids = self.to_tables(df, n_ulmfit_features=n_ulmfit_features if use_crf else n_classes)
169+
if use_crf:
170+
preds = self.crf.predict(tables)
171+
else:
172+
preds = [table[..., :n_classes].argmax(axis=-1) for table in tables]
166173
return self.format_predictions(preds, ids)
167174

168175
# todo: consider adding sota/ablation information
@@ -179,10 +186,10 @@ def label_table(self, paper, table, annotations, in_place):
179186
return table
180187

181188
# todo: take EvidenceExtractor in constructor
182-
def label_tables(self, paper, tables, raw_evidences, in_place=False):
189+
def label_tables(self, paper, tables, raw_evidences, in_place=False, use_crf=True):
183190
pipeline_logger(f"{TableStructurePredictor.step}::label_tables", paper=paper, tables=tables, raw_evidences=raw_evidences)
184191
if len(raw_evidences):
185-
tags = self.predict_tags(raw_evidences)
192+
tags = self.predict_tags(raw_evidences, use_crf)
186193
annotations = dict(list(tags.groupby(by=["paper", "table"])))
187194
else:
188195
annotations = {} # just deep-copy all tables

0 commit comments

Comments
 (0)