@@ -25,14 +25,22 @@ class Labels(Enum):
25
25
"model-competing" : Labels .COMPETING_MODEL .value
26
26
}
27
27
28
# Compiled once at module import; only _limit_context uses this pattern.
# The capturing group makes re.split keep the <b>...</b> fragments in the result.
elastic_tag_split_re = re.compile(r"(<b>.*?</b>)")
28
31
@dataclass
29
32
class Experiment :
30
33
vectorizer : str = "tfidf"
31
34
this_paper : bool = False
32
35
merge_fragments : bool = False
36
+ merge_type : str = "concat" # "concat", "vote_maj", "vote_avg", "vote_max"
33
37
evidence_source : str = "text" # "text" or "text_highlited"
34
38
split_btags : bool = False # <b>Test</b> -> <b> Test </b>
35
- fixed_tokenizer : bool = False # <b> and </b> are not split
39
+ fixed_tokenizer : bool = False # if True, <b> and </b> are not split into < b > and < / b >
40
+ fixed_this_paper : bool = False # if True and this_paper, filter this_paper before merging fragments
41
+ mask : bool = False # if True and evidence_source = "text_highlited", replace <b>...</b> with xxmask
42
+ evidence_limit : int = None # maximum number of evidences per cell (grouped by (ext_id, this_paper))
43
+ context_tokens : int = None # max. number of words before <b> and after </b>
36
44
37
45
class_weight : str = None
38
46
multinomial_type : str = "manual" # "manual", "ovr", "multinomial"
@@ -107,17 +115,61 @@ def get_trained_model(self, train_df):
107
115
self .has_model = True
108
116
return nbsvm
109
117
118
def _limit_context(self, text):
    """Trim *text* so that at most ``self.context_tokens`` words are kept
    on each side of every highlighted ``<b>...</b>`` fragment.

    The text is split with the module-level, precompiled
    ``elastic_tag_split_re``; because the pattern has a capturing group,
    even-indexed parts are plain text and odd-indexed parts are the
    ``<b>...</b>`` fragments, which are kept verbatim.
    """
    parts = elastic_tag_split_re.split(text)
    new_parts = []
    # re.split with a capturing group returns an odd-length list, so the
    # last element (index len(parts) - 1) is always plain text.
    # Bugfix: this was `end = len(parts)`, which the enumerate index can
    # never reach, so the trailing context was never head-truncated and
    # instead fell into the middle-segment branch.
    last = len(parts) - 1
    for i, part in enumerate(parts):
        if i % 2 == 0:
            # Plain text between/around highlights.
            toks = tokenize(part)
            if i == 0:
                # Leading context: keep only the words just before <b>.
                toks = toks[-self.context_tokens:]
            elif i == last:
                # Trailing context: keep only the words just after </b>.
                toks = toks[:self.context_tokens]
            else:
                # Between two highlights: if too long, keep the head
                # (after the previous </b>) and the tail (before the
                # next <b>).
                overflow = len(toks) - 2 * self.context_tokens
                if overflow > 0:
                    toks = toks[:self.context_tokens] + toks[-self.context_tokens:]
            new_parts.append(' '.join(toks))
        else:
            # <b>...</b> fragment — preserved untouched.
            new_parts.append(part)
    return ' '.join(new_parts)
110
140
def _transform_df (self , df ):
141
+ if self .merge_type not in ["concat" , "vote_maj" , "vote_avg" , "vote_max" ]:
142
+ raise Exception (f"merge_type must be one of concat, vote_maj, vote_avg, vote_max, but { self .merge_type } was given" )
111
143
df = df [df ["cell_type" ] != "table-meta" ] # otherwise we get precision 0 on test set
144
+ if self .evidence_limit is not None :
145
+ df = df .groupby (by = ["ext_id" , "this_paper" ]).head (self .evidence_limit )
146
+ if self .context_tokens is not None :
147
+ df .loc ["text_highlited" ] = df ["text_highlited" ].apply (self ._limit_context )
148
+ df .loc ["text" ] = df ["text_highlited" ].str .replace ("<b>" , " " ).replace ("</b>" , " " )
112
149
if self .evidence_source != "text" :
113
150
df = df .copy (True )
114
- df ["text" ] = df [self .evidence_source ]
115
- if self .merge_fragments :
116
- df = df .groupby (by = ["ext_id" , "cell_content" , "cell_type" , "this_paper" ]).text .apply (
117
- lambda x : "\n " .join (x .values )).reset_index ()
118
- df = df .drop_duplicates (["text" , "cell_content" , "cell_type" ]).fillna ("" )
119
- if self .this_paper :
120
- df = df [df .this_paper ]
151
+ if self .mask :
152
+ df ["text" ] = df [self .evidence_source ].replace (re .compile ("<b>.*?</b>" ), " xxmask " )
153
+ else :
154
+ df ["text" ] = df [self .evidence_source ]
155
+
156
+ elif self .mask :
157
+ raise Exception ("Masking with evidence_source='text' makes no sense" )
158
+ if not self .fixed_this_paper :
159
+ if self .merge_fragments and self .merge_type == "concat" :
160
+ df = df .groupby (by = ["ext_id" , "cell_content" , "cell_type" , "this_paper" ]).text .apply (
161
+ lambda x : "\n " .join (x .values )).reset_index ()
162
+ df = df .drop_duplicates (["text" , "cell_content" , "cell_type" ]).fillna ("" )
163
+ if self .this_paper :
164
+ df = df [df .this_paper ]
165
+ else :
166
+ if self .this_paper :
167
+ df = df [df .this_paper ]
168
+ if self .merge_fragments and self .merge_type == "concat" :
169
+ df = df .groupby (by = ["ext_id" , "cell_content" , "cell_type" ]).text .apply (
170
+ lambda x : "\n " .join (x .values )).reset_index ()
171
+ df = df .drop_duplicates (["text" , "cell_content" , "cell_type" ]).fillna ("" )
172
+
121
173
if self .split_btags :
122
174
df ["text" ] = df ["text" ].replace (re .compile (r"(\</?b\>)" ), r" \1 " )
123
175
df = df .replace (re .compile (r"(xxref|xxanchor)-[\w\d-]*" ), "\\ 1 " )
@@ -135,9 +187,20 @@ def evaluate(self, model, train_df, valid_df, test_df):
135
187
for prefix , tdf in zip (["train" , "valid" , "test" ], [train_df , valid_df , test_df ]):
136
188
probs = model .predict_proba (tdf ["text" ])
137
189
preds = np .argmax (probs , axis = 1 )
138
- true_y = tdf ["label" ]
139
190
140
- m = metrics (preds , tdf .label )
191
+ if self .merge_fragments and self .merge_type != "concat" :
192
+ if self .merge_type == "vote_maj" :
193
+ vote_results = preds_for_cell_content (tdf , probs )
194
+ elif self .merge_type == "vote_avg" :
195
+ vote_results = preds_for_cell_content_multi (tdf , probs )
196
+ elif self .merge_type == "vote_max" :
197
+ vote_results = preds_for_cell_content_max (tdf , probs )
198
+ preds = vote_results ["pred" ]
199
+ true_y = vote_results ["true" ]
200
+ else :
201
+ true_y = tdf ["label" ]
202
+
203
+ m = metrics (preds , true_y )
141
204
r = {}
142
205
r [f"{ prefix } _accuracy" ] = m ["accuracy" ]
143
206
r [f"{ prefix } _precision" ] = m ["precision" ]
0 commit comments