
Commit 1db8c1f

Marcin Kardas committed

Tune searching

1 parent 584276e

File tree

4 files changed: +75 −32 lines


sota_extractor2/config.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 
 # otherwise use this files
 data = Path("/mnt/efs/pwc/data")
-goldtags_dump = data / "dumps" / "goldtags-2019.06.28_0916.json.gz"
+goldtags_dump = data / "dumps" / "goldtags-2019.07.16_2214.json.gz"
 
 
 elastic = dict(hosts=['localhost'], timeout=20)
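
The change only points the config at a newer gold-tags dump. For context, a minimal sketch of how such a dump could be read (assuming, as the .json.gz suffix suggests, a gzip-compressed JSON file; load_goldtags is a hypothetical helper, not part of the repo):

import gzip
import json
from pathlib import Path

data = Path("/mnt/efs/pwc/data")
goldtags_dump = data / "dumps" / "goldtags-2019.07.16_2214.json.gz"

def load_goldtags(path):
    # Hypothetical loader: decompress the dump and parse it as JSON.
    with gzip.open(path, "rt", encoding="utf-8") as f:
        return json.load(f)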

sota_extractor2/data/structure.py

Lines changed: 50 additions & 18 deletions
@@ -20,7 +20,23 @@ def consume_cells(*matrix):
             yield Cell(row=row_id, col=col_id, vals=cell_val)
 
 
-def fetch_evidence(cell_content, paper_id, paper_limit=10, corpus_limit=10):
+reference_re = re.compile(r"\[[^]]*\]")
+ours_re = re.compile(r"\(ours?\)")
+all_parens_re = re.compile(r"\([^)]*\)")
+
+
+def clear_cell(s):
+    for pat in [reference_re, all_parens_re]:
+        s = pat.sub("", s)
+    s = s.strip()
+    return s
+
+
+def fetch_evidence(cell_content, cell_reference, paper_id, paper_limit=10, corpus_limit=10):
+    cell_content = clear_cell(cell_content)
+    if cell_content == "" and cell_reference == "":
+        return []
+
     evidence_query = Fragment.search().highlight(
         'text', pre_tags="<b>", post_tags="</b>", fragment_size=400)
     cell_content = cell_content.replace("\xa0", " ")
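
The new clear_cell strips bracketed references and any parenthesised suffix before querying, so only the bare cell text is matched; note that ours_re is compiled here but not applied inside clear_cell. A quick check of the behaviour (expected outputs in comments):

import re

reference_re = re.compile(r"\[[^]]*\]")
all_parens_re = re.compile(r"\([^)]*\)")

def clear_cell(s):
    for pat in [reference_re, all_parens_re]:
        s = pat.sub("", s)
    return s.strip()

print(clear_cell("ResNet-50 [12]"))     # "ResNet-50"
print(clear_cell("BERT-large (ours)"))  # "BERT-large"
print(clear_cell("[5] (see text)"))     # "" -> triggers the early return above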
@@ -31,10 +47,21 @@ def fetch_evidence(cell_content, paper_id, paper_limit=10, corpus_limit=10):
     paper_fragments = list(evidence_query
                            .filter('term', paper_id=paper_id)
                            .query('match_phrase', text=query)[:paper_limit])
+    if cell_reference != "":
+        reference_fragments = list(evidence_query
+                                   .filter('term', paper_id=paper_id)
+                                   .query('match_phrase', text={
+                                       "query": cell_reference,
+                                       "slop": 1
+                                   })[:paper_limit])
+    else:
+        reference_fragments = []
     other_fagements = list(evidence_query
                            .exclude('term', paper_id=paper_id)
                            .query('match_phrase', text=query)[:corpus_limit])
-    return paper_fragments + other_fagements
+    if not len(paper_fragments) and not len(reference_fragments) and not len(other_fagements):
+        print(f"No evidences for '{cell_content}' of {paper_id}")
+    return paper_fragments + reference_fragments + other_fagements
 
 fix_refs_re = re.compile('\(\?\)|\s[?]+(\s|$)')
 
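The reference lookup uses match_phrase with slop=1, which lets the phrase terms sit up to one position apart in the indexed text. A standalone elasticsearch-dsl sketch of the same query shape (index name and paper id are illustrative assumptions, not values from the repo):

from elasticsearch_dsl import Search

s = (Search(index="fragments")                # index name is an assumption
     .filter("term", paper_id="1234.56789")   # hypothetical paper id
     .query("match_phrase", text={"query": "xxref-some-paper", "slop": 1})[:10])
# fragments = list(s)  # executed lazily, as in fetch_evidence above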
@@ -44,29 +71,34 @@ def fix_refs(text):
 
 
 highlight_re = re.compile("</?b>")
+partial_highlight_re = re.compile(r"\<b\>xxref\</b\>-(?!\<b\>)")
+
+
+def fix_reference_hightlight(s):
+    return partial_highlight_re.sub("xxref-", s)
 
 
 def create_evidence_records(textfrag, cell, table):
     for text_highlited in textfrag.meta['highlight']['text']:
-        text_highlited = fix_refs(text_highlited)
+        text_highlited = fix_reference_hightlight(fix_refs(text_highlited))
         text = highlight_re.sub("", text_highlited)
         text_sha1 = hashlib.sha1(text.encode("utf-8")).hexdigest()
 
         cell_ext_id = f"{table.ext_id}/{cell.row}/{cell.col}"
 
-        if len(text.split()) > 50:
-            yield {"text_sha1": text_sha1,
-                   "text_highlited": text_highlited,
-                   "text": text,
-                   "header": textfrag.header,
-                   "cell_type": cell.vals[1],
-                   "cell_content": fix_refs(cell.vals[0]),
-                   "this_paper": textfrag.paper_id == table.paper_id,
-                   "row": cell.row,
-                   "col": cell.col,
-                   "ext_id": cell_ext_id
-                   #"table_id":table_id
-                   }
+        yield {"text_sha1": text_sha1,
+               "text_highlited": text_highlited,
+               "text": text,
+               "header": textfrag.header,
+               "cell_type": cell.vals[1],
+               "cell_content": fix_refs(cell.vals[0]),
+               "cell_reference": cell.vals[2],
+               "this_paper": textfrag.paper_id == table.paper_id,
+               "row": cell.row,
+               "col": cell.col,
+               "ext_id": cell_ext_id
+               #"table_id":table_id
+               }
 
 
 def filter_cells(cell):
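
fix_reference_hightlight repairs a highlighting artifact: when Elasticsearch highlights only the xxref prefix of a reference token, the stray <b>…</b> tags are dropped; the negative lookahead leaves fully highlighted references untouched. For example:

import re

partial_highlight_re = re.compile(r"\<b\>xxref\</b\>-(?!\<b\>)")

def fix_reference_hightlight(s):
    return partial_highlight_re.sub("xxref-", s)

# Only the "xxref" prefix was highlighted -> strip the tags:
print(fix_reference_hightlight("see <b>xxref</b>-he2016deep"))
# "see xxref-he2016deep"

# The reference id is highlighted too -> left untouched:
print(fix_reference_hightlight("see <b>xxref</b>-<b>he2016deep</b>"))
# "see <b>xxref</b>-<b>he2016deep</b>"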
@@ -83,8 +115,8 @@ def get_limits(cell_type):
         return dict(paper_limit=paper_limit, corpus_limit=corpus_limit)
     records = [
         record
-        for cell in consume_cells(table.matrix, table.matrix_gold_tags) if filter_cells(cell)
-        for evidence in fetch_evidence(cell.vals[0], paper_id=table.paper_id, **get_limits(cell.vals[1]))
+        for cell in consume_cells(table.matrix, table.matrix_gold_tags, table.matrix_references) if filter_cells(cell)
+        for evidence in fetch_evidence(cell.vals[0], cell.vals[2], paper_id=table.paper_id, **get_limits(cell.vals[1]))
         for record in create_evidence_records(evidence, cell, table=table)
     ]
     df = pd.DataFrame.from_records(records)
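
This listing relies on consume_cells zipping the parallel matrices cell-wise, so after this change cell.vals is (content, gold_tag, reference) and cell.vals[2] feeds the new cell_reference argument. A minimal sketch of that contract, consistent with the signature and yield shown in the first hunk (the body, Cell type, and example values here are assumptions, not the repo's code):

from collections import namedtuple

Cell = namedtuple("Cell", "row col vals")

def consume_cells(*matrix):
    # Zip parallel matrices so vals[i] comes from the i-th matrix;
    # hence cell.vals[2] is the reference matrix entry.
    for row_id, rows in enumerate(zip(*matrix)):
        for col_id, cell_val in enumerate(zip(*rows)):
            yield Cell(row=row_id, col=col_id, vals=cell_val)

content    = [["model", "acc."], ["ResNet [3]", "94.1"]]
gold_tags  = [["", ""], ["model-best", "table_results"]]   # hypothetical tags
references = [["", ""], ["xxref-he2016deep", ""]]

for cell in consume_cells(content, gold_tags, references):
    print(cell.row, cell.col, cell.vals)
# e.g. 1 0 ('ResNet [3]', 'model-best', 'xxref-he2016deep')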

sota_extractor2/models/structure/experiment.py

Lines changed: 16 additions & 9 deletions
@@ -41,6 +41,8 @@ class Experiment:
     mask: bool = False           # if True and evidence_source = "text_highlited", replace <b>...</b> with xxmask
     evidence_limit: int = None   # maximum number of evidences per cell (grouped by (ext_id, this_paper))
     context_tokens: int = None   # max. number of words before <b> and after </b>
+    analyzer: str = "word"       # "char", "word" or "char_wb"
+    lowercase: bool = True
 
     class_weight: str = None
     multinomial_type: str = "manual"  # "manual", "ovr", "multinomial"
@@ -140,7 +142,7 @@ def _limit_context(self, text):
     def _transform_df(self, df):
         if self.merge_type not in ["concat", "vote_maj", "vote_avg", "vote_max"]:
             raise Exception(f"merge_type must be one of concat, vote_maj, vote_avg, vote_max, but {self.merge_type} was given")
-        df = df[df["cell_type"] != "table-meta"]  # otherwise we get precision 0 on test set
+        #df = df[df["cell_type"] != "table-meta"]  # otherwise we get precision 0 on test set
         if self.evidence_limit is not None:
             df = df.groupby(by=["ext_id", "this_paper"]).head(self.evidence_limit)
         if self.context_tokens is not None:
@@ -181,7 +183,18 @@ def _transform_df(self, df):
         return df
 
     def transform_df(self, *dfs):
-        return [self._transform_df(df) for df in dfs]
+        transformed = [self._transform_df(df) for df in dfs]
+        if len(transformed) == 1:
+            return transformed[0]
+        return transformed
+
+    def _set_results(self, prefix, preds, true_y):
+        m = metrics(preds, true_y)
+        r = {}
+        r[f"{prefix}_accuracy"] = m["accuracy"]
+        r[f"{prefix}_precision"] = m["precision"]
+        r[f"{prefix}_cm"] = confusion_matrix(true_y, preds).tolist()
+        self.update_results(**r)
 
     def evaluate(self, model, train_df, valid_df, test_df):
         for prefix, tdf in zip(["train", "valid", "test"], [train_df, valid_df, test_df]):
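
Behavioural note: transform_df now unwraps a single result, so a one-DataFrame call no longer returns a one-element list. Assuming e is an Experiment instance (illustrative usage, not code from the repo):

train_df = e.transform_df(train_df)  # single argument: a DataFrame, not [DataFrame]
train_df, valid_df, test_df = e.transform_df(train_df, valid_df, test_df)  # multiple: a list, unpacks as before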
@@ -199,13 +212,7 @@ def evaluate(self, model, train_df, valid_df, test_df):
                 true_y = vote_results["true"]
             else:
                 true_y = tdf["label"]
-
-            m = metrics(preds, true_y)
-            r = {}
-            r[f"{prefix}_accuracy"] = m["accuracy"]
-            r[f"{prefix}_precision"] = m["precision"]
-            r[f"{prefix}_cm"] = confusion_matrix(true_y, preds).tolist()
-            self.update_results(**r)
+            self._set_results(prefix, preds, true_y)
 
     def show_results(self, *ds):
         if not len(ds):

sota_extractor2/models/structure/nbsvm.py

Lines changed: 8 additions & 4 deletions
@@ -63,7 +63,7 @@ def get_mdl(self, y):
         r = np.log(self.pr(1, y) / self.pr(0, y))
         m = LogisticRegression(C=self.experiment.C, penalty=self.experiment.penalty,
                                dual=self.experiment.dual, solver=self.experiment.solver,
-                               max_iter=self.experiment.max_iter)
+                               max_iter=self.experiment.max_iter, class_weight=self.experiment.class_weight)
         x_nb = self.trn_term_doc.multiply(r)
         return m.fit(x_nb, y), r
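With class_weight forwarded to the per-class logistic regressions, class_weight='balanced' reweights each binary subproblem by inverse class frequency, which scikit-learn computes as n_samples / (n_classes * bincount(y)). For example:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0, 0, 0, 0, 0, 0, 1, 1])  # imbalanced binary labels
print(compute_class_weight("balanced", classes=np.array([0, 1]), y=y))
# [0.667 2.0]: 8 / (2 * 6) for class 0, 8 / (2 * 2) for class 1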
@@ -74,11 +74,15 @@ def bow(self, X_train):
         if self.experiment.vectorizer == "tfidf":
             self.vec = TfidfVectorizer(ngram_range=self.experiment.ngram_range,
                                        tokenizer=tokenizer,
+                                       lowercase=self.experiment.lowercase,
+                                       analyzer=self.experiment.analyzer,
                                        min_df=self.experiment.min_df, max_df=self.experiment.max_df,
                                        strip_accents='unicode', use_idf=1,
                                        smooth_idf=1, sublinear_tf=1)
         elif self.experiment.vectorizer == "count":
             self.vec = CountVectorizer(ngram_range=self.experiment.ngram_range, tokenizer=tokenizer,
+                                       analyzer=self.experiment.analyzer,
+                                       lowercase=self.experiment.lowercase,
                                        min_df=self.experiment.min_df, max_df=self.experiment.max_df,
                                        strip_accents='unicode')
         else:
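
The new analyzer and lowercase experiment options are forwarded verbatim to the scikit-learn vectorizers. analyzer='char_wb' builds character n-grams that never cross whitespace token boundaries (each token is padded with a space), which can be more robust than word tokens for short, noisy cell contents. A quick illustration:

from sklearn.feature_extraction.text import CountVectorizer

vec = CountVectorizer(analyzer="char_wb", ngram_range=(3, 3), lowercase=True)
vec.fit(["ResNet-50"])
print(vec.get_feature_names_out())
# nine 3-grams of the padded token " resnet-50 ":
# ' re', 'res', 'esn', 'sne', 'net', 'et-', 't-5', '-50', '50 ' (sorted in the output)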
@@ -93,11 +97,11 @@ def train_models(self, y_train):
                 #print('fit', i)
                 m, r = self.get_mdl(get_class_column(y_train, i))
                 self.models.append((m, r))
-        elif self.experiment.multinomial_type == "multinomial":
+        elif self.experiment.multinomial_type == "multinomial" or self.experiment.multinomial_type == "ovr":
             m = LogisticRegression(C=self.experiment.C, penalty=self.experiment.penalty,
                                    dual=self.experiment.dual, solver=self.experiment.solver,
                                    max_iter=self.experiment.max_iter,
-                                   multi_class="multinomial", class_weight=self.experiment.class_weight)
+                                   multi_class=self.experiment.multinomial_type, class_weight=self.experiment.class_weight)
             x_nb = self.trn_term_doc
             self.models.append(m.fit(x_nb, y_train))
         else:
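
After this change, "ovr" no longer triggers the manual per-class NB-SVM loop; both "ovr" and "multinomial" are handed to a single LogisticRegression as its multi_class argument (a supported parameter in the scikit-learn versions of this commit's era). Roughly:

from sklearn.linear_model import LogisticRegression

# multinomial_type == "ovr": one binary logistic fit per class,
# probabilities renormalized across classes at predict time.
ovr = LogisticRegression(multi_class="ovr", solver="lbfgs", max_iter=1000)

# multinomial_type == "multinomial": a single softmax (cross-entropy)
# fit over all classes jointly.
softmax = LogisticRegression(multi_class="multinomial", solver="lbfgs", max_iter=1000)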
@@ -115,7 +119,7 @@ def predict_proba(self, X_test):
             for i in range(0, self.c):
                 m, r = self.models[i]
                 preds[:, i] = m.predict_proba(test_term_doc.multiply(r))[:, 1]
-        elif self.experiment.multinomial_type == "multinomial":
+        elif self.experiment.multinomial_type == "multinomial" or self.experiment.multinomial_type == "ovr":
             preds = self.models[0].predict_proba(test_term_doc)
         else:
             raise Exception(f"Unsupported multinomial_type {self.experiment.multinomial_type}")
