Skip to content

Commit 584276e

Browse files
author
Marcin Kardas
committed
Add voting strategies to Experiment
1 parent 5a9255a commit 584276e

File tree

2 files changed

+92
-23
lines changed

2 files changed

+92
-23
lines changed

sota_extractor2/models/structure/experiment.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,22 @@ class Labels(Enum):
2525
"model-competing": Labels.COMPETING_MODEL.value
2626
}
2727

28+
# put here to avoid recompiling, used only in _limit_context
29+
elastic_tag_split_re = re.compile("(<b>.*?</b>)")
30+
2831
@dataclass
2932
class Experiment:
3033
vectorizer: str = "tfidf"
3134
this_paper: bool = False
3235
merge_fragments: bool = False
36+
merge_type: str = "concat" # "concat", "vote_maj", "vote_avg", "vote_max"
3337
evidence_source: str = "text" # "text" or "text_highlited"
3438
split_btags: bool = False # <b>Test</b> -> <b> Test </b>
35-
fixed_tokenizer: bool = False # <b> and </b> are not split
39+
fixed_tokenizer: bool = False # if True, <b> and </b> are not split into < b > and < / b >
40+
fixed_this_paper: bool = False # if True and this_paper, filter this_paper before merging fragments
41+
mask: bool = False # if True and evidence_source = "text_highlited", replace <b>...</b> with xxmask
42+
evidence_limit: int = None # maximum number of evidences per cell (grouped by (ext_id, this_paper))
43+
context_tokens: int = None # max. number of words before <b> and after </b>
3644

3745
class_weight: str = None
3846
multinomial_type: str = "manual" # "manual", "ovr", "multinomial"
@@ -107,17 +115,61 @@ def get_trained_model(self, train_df):
107115
self.has_model = True
108116
return nbsvm
109117

118+
def _limit_context(self, text):
119+
parts = elastic_tag_split_re.split(text)
120+
new_parts = []
121+
end = len(parts)
122+
for i, part in enumerate(parts):
123+
if i % 2 == 0:
124+
toks = tokenize(part)
125+
if i == 0:
126+
toks = toks[-self.context_tokens:]
127+
elif i == end:
128+
toks = toks[:self.context_tokens]
129+
else:
130+
j = len(toks) - 2 * self.context_tokens
131+
if j > 0:
132+
toks = toks[:self.context_tokens] + toks[-self.context_tokens:]
133+
new_parts.append(' '.join(toks))
134+
else:
135+
new_parts.append(part)
136+
return ' '.join(new_parts)
137+
138+
139+
110140
def _transform_df(self, df):
141+
if self.merge_type not in ["concat", "vote_maj", "vote_avg", "vote_max"]:
142+
raise Exception(f"merge_type must be one of concat, vote_maj, vote_avg, vote_max, but {self.merge_type} was given")
111143
df = df[df["cell_type"] != "table-meta"] # otherwise we get precision 0 on test set
144+
if self.evidence_limit is not None:
145+
df = df.groupby(by=["ext_id", "this_paper"]).head(self.evidence_limit)
146+
if self.context_tokens is not None:
147+
df.loc["text_highlited"] = df["text_highlited"].apply(self._limit_context)
148+
df.loc["text"] = df["text_highlited"].str.replace("<b>", " ").replace("</b>", " ")
112149
if self.evidence_source != "text":
113150
df = df.copy(True)
114-
df["text"] = df[self.evidence_source]
115-
if self.merge_fragments:
116-
df = df.groupby(by=["ext_id", "cell_content", "cell_type", "this_paper"]).text.apply(
117-
lambda x: "\n".join(x.values)).reset_index()
118-
df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
119-
if self.this_paper:
120-
df = df[df.this_paper]
151+
if self.mask:
152+
df["text"] = df[self.evidence_source].replace(re.compile("<b>.*?</b>"), " xxmask ")
153+
else:
154+
df["text"] = df[self.evidence_source]
155+
156+
elif self.mask:
157+
raise Exception("Masking with evidence_source='text' makes no sense")
158+
if not self.fixed_this_paper:
159+
if self.merge_fragments and self.merge_type == "concat":
160+
df = df.groupby(by=["ext_id", "cell_content", "cell_type", "this_paper"]).text.apply(
161+
lambda x: "\n".join(x.values)).reset_index()
162+
df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
163+
if self.this_paper:
164+
df = df[df.this_paper]
165+
else:
166+
if self.this_paper:
167+
df = df[df.this_paper]
168+
if self.merge_fragments and self.merge_type == "concat":
169+
df = df.groupby(by=["ext_id", "cell_content", "cell_type"]).text.apply(
170+
lambda x: "\n".join(x.values)).reset_index()
171+
df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
172+
121173
if self.split_btags:
122174
df["text"] = df["text"].replace(re.compile(r"(\</?b\>)"), r" \1 ")
123175
df = df.replace(re.compile(r"(xxref|xxanchor)-[\w\d-]*"), "\\1 ")
@@ -135,9 +187,20 @@ def evaluate(self, model, train_df, valid_df, test_df):
135187
for prefix, tdf in zip(["train", "valid", "test"], [train_df, valid_df, test_df]):
136188
probs = model.predict_proba(tdf["text"])
137189
preds = np.argmax(probs, axis=1)
138-
true_y = tdf["label"]
139190

140-
m = metrics(preds, tdf.label)
191+
if self.merge_fragments and self.merge_type != "concat":
192+
if self.merge_type == "vote_maj":
193+
vote_results = preds_for_cell_content(tdf, probs)
194+
elif self.merge_type == "vote_avg":
195+
vote_results = preds_for_cell_content_multi(tdf, probs)
196+
elif self.merge_type == "vote_max":
197+
vote_results = preds_for_cell_content_max(tdf, probs)
198+
preds = vote_results["pred"]
199+
true_y = vote_results["true"]
200+
else:
201+
true_y = tdf["label"]
202+
203+
m = metrics(preds, true_y)
141204
r = {}
142205
r[f"{prefix}_accuracy"] = m["accuracy"]
143206
r[f"{prefix}_precision"] = m["precision"]

sota_extractor2/models/structure/nbsvm.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,21 @@ def get_number_of_classes(y):
3838
else:
3939
return y.shape[1]
4040

41+
re_tok = re.compile(f'([{string.punctuation}“”¨«»®´·º½¾¿¡§£₤‘’])')
42+
re_tok_fixed = re.compile(
43+
f'([{string.punctuation}“”¨«»®´·º½¾¿¡§£₤‘’])'.replace('<', '').replace('>', '').replace('/', ''))
44+
45+
def tokenize(s):
46+
return re_tok.sub(r' \1 ', s).split()
47+
48+
def tokenize_fixed(s):
49+
return re_tok_fixed.sub(r' \1 ', s).split()
50+
51+
4152
class NBSVM:
4253
def __init__(self, experiment):
4354
self.experiment = experiment
4455

45-
re_tok = re.compile(f'([{string.punctuation}“”¨«»®´·º½¾¿¡§£₤‘’])')
46-
re_tok_fixed = re.compile(f'([{string.punctuation}“”¨«»®´·º½¾¿¡§£₤‘’])'.replace('<', '').replace('>', '').replace('/', ''))
47-
48-
def tokenize(self, s):
49-
return self.re_tok.sub(r' \1 ', s).split()
50-
51-
def tokenize_fixed(self, s):
52-
return self.re_tok_fixed.sub(r' \1 ', s).split()
5356

5457
def pr(self, y_i, y):
5558
p = self.trn_term_doc[y == y_i].sum(0)
@@ -67,14 +70,15 @@ def get_mdl(self, y):
6770
def bow(self, X_train):
6871
self.n = X_train.shape[0]
6972

73+
tokenizer = tokenize_fixed if self.experiment.fixed_tokenizer else tokenize
7074
if self.experiment.vectorizer == "tfidf":
7175
self.vec = TfidfVectorizer(ngram_range=self.experiment.ngram_range,
72-
tokenizer=self.tokenize_fixed if self.experiment.fixed_tokenizer else self.tokenize,
76+
tokenizer=tokenizer,
7377
min_df=self.experiment.min_df, max_df=self.experiment.max_df,
7478
strip_accents='unicode', use_idf=1,
7579
smooth_idf=1, sublinear_tf=1)
7680
elif self.experiment.vectorizer == "count":
77-
self.vec = CountVectorizer(ngram_range=self.experiment.ngram_range, tokenizer=self.tokenize,
81+
self.vec = CountVectorizer(ngram_range=self.experiment.ngram_range, tokenizer=tokenizer,
7882
min_df=self.experiment.min_df, max_df=self.experiment.max_df,
7983
strip_accents='unicode')
8084
else:
@@ -122,7 +126,7 @@ def sort_features_by_importance(self, label):
122126
names = np.array(self.vec.get_feature_names())
123127
if self.experiment.multinomial_type == "manual":
124128
m, r = self.models[label]
125-
f = m.coef_[0] * np.array(r[0])
129+
f = m.coef_[0] * np.array(r)[0]
126130
elif self.experiment.multinomial_type == "multinomial":
127131
f = self.models[0].coef_[label]
128132
else:
@@ -133,6 +137,8 @@ def sort_features_by_importance(self, label):
133137
return names[indices], f[indices]
134138

135139
def get_mismatched(self, df, true_label, predicted_label):
140+
if self.experiment.merge_fragments and self.experiment.merge_type != "concat":
141+
print("warning: the returned results are before merging")
136142
true_label = true_label.value
137143
predicted_label = predicted_label.value
138144

@@ -194,12 +200,12 @@ def preds_for_cell_content_multi(test_df, probs, group_by=["cell_content"]):
194200
'counts': grouped_counts})
195201
return results
196202

197-
def preds_for_cell_content_best(test_df, probs, group_by=["cell_content"]):
203+
def preds_for_cell_content_max(test_df, probs, group_by=["cell_content"]):
198204
test_df = test_df.copy()
199205
probs_df = pd.DataFrame(probs, index=test_df.index)
200206
test_df = pd.concat([test_df, probs_df], axis=1)
201207
grouped_preds = np.argmax(test_df.groupby(
202-
group_by)[probs_df.columns].sum().values, axis=1)
208+
group_by)[probs_df.columns].max().values, axis=1)
203209
grouped_counts = test_df.groupby(group_by)["label"].count()
204210
results = pd.DataFrame({'true': test_df.groupby(group_by)["label"].agg(lambda x: x.value_counts().index[0]),
205211
'pred': grouped_preds,

0 commit comments

Comments
 (0)