Skip to content

Commit 7de755c

Browse files
committed
Add metric class and empty evidences
1 parent 1db8c1f commit 7de755c

File tree

2 files changed

+50
-17
lines changed

2 files changed

+50
-17
lines changed

sota_extractor2/data/structure.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ def clear_cell(s):
3232
return s
3333

3434

35+
def empty_fragment(paper_id):
36+
fragment = Fragment(paper_id=paper_id)
37+
fragment.meta['highlight'] = {'text': ['']}
38+
return fragment
39+
40+
3541
def fetch_evidence(cell_content, cell_reference, paper_id, paper_limit=10, corpus_limit=10):
3642
cell_content = clear_cell(cell_content)
3743
if cell_content == "" and cell_reference == "":
@@ -61,6 +67,8 @@ def fetch_evidence(cell_content, cell_reference, paper_id, paper_limit=10, corpu
6167
.query('match_phrase', text=query)[:corpus_limit])
6268
if not len(paper_fragments) and not len(reference_fragments) and not len(other_fagements):
6369
print(f"No evidences for '{cell_content}' of {paper_id}")
70+
if not len(paper_fragments) and not len(reference_fragments):
71+
paper_fragments = [empty_fragment(paper_id)]
6472
return paper_fragments + reference_fragments + other_fagements
6573

6674
fix_refs_re = re.compile('\(\?\)|\s[?]+(\s|$)')
@@ -124,12 +132,13 @@ def get_limits(cell_type):
124132

125133

126134
def prepare_data(tables, csv_path, limit_type='interesting'):
127-
df = pd.concat([evidence_for_table(table,
135+
df = pd.concat([evidence_for_table(table,
128136
paper_limit=100,
129137
corpus_limit=20,
130138
limit_type=limit_type) for table in progress_bar(tables)])
131-
df = df.drop_duplicates(
132-
["cell_content", "text_highlited", "cell_type", "this_paper"])
139+
#moved to experiment preprocessing
140+
#df = df.drop_duplicates(
141+
# ["cell_content", "text_highlited", "cell_type", "this_paper"])
133142
print("Number of text fragments ", len(df))
134143
csv_path.parent.mkdir(parents=True, exist_ok=True)
135144
df.to_csv(csv_path, index=None)

sota_extractor2/models/structure/experiment.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ class Labels(Enum):
1616
DATASET=1
1717
PAPER_MODEL=2
1818
COMPETING_MODEL=3
19+
METRIC=4
1920

2021
label_map = {
2122
"dataset": Labels.DATASET.value,
2223
"dataset-sub": Labels.DATASET.value,
2324
"model-paper": Labels.PAPER_MODEL.value,
2425
"model-best": Labels.PAPER_MODEL.value,
25-
"model-competing": Labels.COMPETING_MODEL.value
26+
"model-competing": Labels.COMPETING_MODEL.value,
27+
"dataset-metric": Labels.METRIC.value
2628
}
2729

2830
# put here to avoid recompiling, used only in _limit_context
@@ -43,6 +45,9 @@ class Experiment:
4345
context_tokens: int = None # max. number of words before <b> and after </b>
4446
analyzer: str = "word" # "char", "word" or "char_wb"
4547
lowercase: bool = True
48+
remove_num: bool = True
49+
drop_duplicates: bool = True
50+
mark_this_paper: bool = False
4651

4752
class_weight: str = None
4853
multinomial_type: str = "manual" # "manual", "ovr", "multinomial"
@@ -142,6 +147,8 @@ def _limit_context(self, text):
142147
def _transform_df(self, df):
143148
if self.merge_type not in ["concat", "vote_maj", "vote_avg", "vote_max"]:
144149
raise Exception(f"merge_type must be one of concat, vote_maj, vote_avg, vote_max, but {self.merge_type} was given")
150+
if self.mark_this_paper and (self.merge_type != "concat" or self.this_paper):
151+
raise Exception("merge_type must be 'concat' and this_paper must be false")
145152
#df = df[df["cell_type"] != "table-meta"] # otherwise we get precision 0 on test set
146153
if self.evidence_limit is not None:
147154
df = df.groupby(by=["ext_id", "this_paper"]).head(self.evidence_limit)
@@ -154,14 +161,25 @@ def _transform_df(self, df):
154161
df["text"] = df[self.evidence_source].replace(re.compile("<b>.*?</b>"), " xxmask ")
155162
else:
156163
df["text"] = df[self.evidence_source]
157-
158164
elif self.mask:
159165
raise Exception("Masking with evidence_source='text' makes no sense")
160-
if not self.fixed_this_paper:
166+
167+
if self.mark_this_paper:
168+
df = df.groupby(by=["ext_id", "cell_content", "cell_type", "this_paper"]).text.apply(
169+
lambda x: "\n".join(x.values)).reset_index()
170+
this_paper_map = {
171+
True: "this paper",
172+
False: "other paper"
173+
}
174+
df.text = "xxfld 3 " + df.this_paper.apply(this_paper_map.get) + " " + df.text
175+
df = df.groupby(by=["ext_id", "cell_content", "cell_type"]).text.apply(
176+
lambda x: " ".join(x.values)).reset_index()
177+
elif not self.fixed_this_paper:
161178
if self.merge_fragments and self.merge_type == "concat":
162179
df = df.groupby(by=["ext_id", "cell_content", "cell_type", "this_paper"]).text.apply(
163180
lambda x: "\n".join(x.values)).reset_index()
164-
df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
181+
if self.drop_duplicates:
182+
df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
165183
if self.this_paper:
166184
df = df[df.this_paper]
167185
else:
@@ -170,13 +188,15 @@ def _transform_df(self, df):
170188
if self.merge_fragments and self.merge_type == "concat":
171189
df = df.groupby(by=["ext_id", "cell_content", "cell_type"]).text.apply(
172190
lambda x: "\n".join(x.values)).reset_index()
173-
df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
191+
if self.drop_duplicates:
192+
df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
174193

175194
if self.split_btags:
176195
df["text"] = df["text"].replace(re.compile(r"(\</?b\>)"), r" \1 ")
177196
df = df.replace(re.compile(r"(xxref|xxanchor)-[\w\d-]*"), "\\1 ")
178-
df = df.replace(re.compile(r"(^|[ ])\d+\.\d+(\b|%)"), " xxnum ")
179-
df = df.replace(re.compile(r"(^|[ ])\d+(\b|%)"), " xxnum ")
197+
if self.remove_num:
198+
df = df.replace(re.compile(r"(^|[ ])\d+\.\d+(\b|%)"), " xxnum ")
199+
df = df.replace(re.compile(r"(^|[ ])\d+(\b|%)"), " xxnum ")
180200
df = df.replace(re.compile(r"\bdata set\b"), " dataset ")
181201
df["label"] = df["cell_type"].apply(lambda x: label_map.get(x, 0))
182202
df["label"] = pd.Categorical(df["label"])
@@ -193,6 +213,7 @@ def _set_results(self, prefix, preds, true_y):
193213
r = {}
194214
r[f"{prefix}_accuracy"] = m["accuracy"]
195215
r[f"{prefix}_precision"] = m["precision"]
216+
r[f"{prefix}_recall"] = m["recall"]
196217
r[f"{prefix}_cm"] = confusion_matrix(true_y, preds).tolist()
197218
self.update_results(**r)
198219

@@ -214,26 +235,29 @@ def evaluate(self, model, train_df, valid_df, test_df):
214235
true_y = tdf["label"]
215236
self._set_results(prefix, preds, true_y)
216237

217-
def show_results(self, *ds):
238+
def show_results(self, *ds, normalize=True):
218239
if not len(ds):
219240
ds = ["train", "valid", "test"]
220241
for prefix in ds:
221242
print(f"{prefix} dataset")
222-
print(f" * accuracy: {self.results[f'{prefix}_accuracy']}")
223-
print(f" * precision: {self.results[f'{prefix}_precision']}")
224-
self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm']), normalize=True)
243+
print(f" * accuracy: {self.results[f'{prefix}_accuracy']:.3f}")
244+
print(f" * μ-precision: {self.results[f'{prefix}_precision']:.3f}")
245+
print(f" * μ-recall: {self.results[f'{prefix}_recall']:.3f}")
246+
self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm']), normalize=normalize)
225247

226-
def _plot_confusion_matrix(self, cm, normalize):
248+
def _plot_confusion_matrix(self, cm, normalize, fmt=None):
227249
if normalize:
228250
cm = cm / cm.sum(axis=1)[:, None]
229-
target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)"]
251+
if fmt is None:
252+
fmt = "0.2f" if normalize else "d"
253+
target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC"]
230254
df_cm = pd.DataFrame(cm, index=[i for i in target_names],
231255
columns=[i for i in target_names])
232256
plt.figure(figsize=(10, 10))
233257
ax = sn.heatmap(df_cm,
234258
annot=True,
235259
square=True,
236-
fmt="0.2f" if normalize else "d",
260+
fmt=fmt,
237261
cmap="YlGnBu",
238262
mask=cm == 0,
239263
linecolor="black",

0 commit comments

Comments
 (0)