@@ -16,13 +16,15 @@ class Labels(Enum):
     DATASET = 1
     PAPER_MODEL = 2
     COMPETING_MODEL = 3
+    METRIC = 4
 
 label_map = {
     "dataset": Labels.DATASET.value,
     "dataset-sub": Labels.DATASET.value,
     "model-paper": Labels.PAPER_MODEL.value,
     "model-best": Labels.PAPER_MODEL.value,
-    "model-competing": Labels.COMPETING_MODEL.value
+    "model-competing": Labels.COMPETING_MODEL.value,
+    "dataset-metric": Labels.METRIC.value
 }
 
 # put here to avoid recompiling, used only in _limit_context
@@ -43,6 +45,9 @@ class Experiment:
     context_tokens: int = None   # max. number of words before <b> and after </b>
     analyzer: str = "word"       # "char", "word" or "char_wb"
     lowercase: bool = True
+    remove_num: bool = True
+    drop_duplicates: bool = True
+    mark_this_paper: bool = False
 
     class_weight: str = None
     multinomial_type: str = "manual"  # "manual", "ovr", "multinomial"
@@ -142,6 +147,8 @@ def _limit_context(self, text):
     def _transform_df(self, df):
         if self.merge_type not in ["concat", "vote_maj", "vote_avg", "vote_max"]:
             raise Exception(f"merge_type must be one of concat, vote_maj, vote_avg, vote_max, but {self.merge_type} was given")
+        if self.mark_this_paper and (self.merge_type != "concat" or self.this_paper):
+            raise Exception("merge_type must be 'concat' and this_paper must be false")
         #df = df[df["cell_type"] != "table-meta"]  # otherwise we get precision 0 on test set
         if self.evidence_limit is not None:
             df = df.groupby(by=["ext_id", "this_paper"]).head(self.evidence_limit)
@@ -154,14 +161,25 @@ def _transform_df(self, df):
                 df["text"] = df[self.evidence_source].replace(re.compile("<b>.*?</b>"), " xxmask ")
             else:
                 df["text"] = df[self.evidence_source]
-
         elif self.mask:
             raise Exception("Masking with evidence_source='text' makes no sense")
-        if not self.fixed_this_paper:
+
+        if self.mark_this_paper:
+            df = df.groupby(by=["ext_id", "cell_content", "cell_type", "this_paper"]).text.apply(
+                lambda x: "\n".join(x.values)).reset_index()
+            this_paper_map = {
+                True: "this paper",
+                False: "other paper"
+            }
+            df.text = "xxfld 3 " + df.this_paper.apply(this_paper_map.get) + " " + df.text
+            df = df.groupby(by=["ext_id", "cell_content", "cell_type"]).text.apply(
+                lambda x: " ".join(x.values)).reset_index()
+        elif not self.fixed_this_paper:
             if self.merge_fragments and self.merge_type == "concat":
                 df = df.groupby(by=["ext_id", "cell_content", "cell_type", "this_paper"]).text.apply(
                     lambda x: "\n".join(x.values)).reset_index()
-            df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
+            if self.drop_duplicates:
+                df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
             if self.this_paper:
                 df = df[df.this_paper]
         else:
@@ -170,13 +188,15 @@ def _transform_df(self, df):
             if self.merge_fragments and self.merge_type == "concat":
                 df = df.groupby(by=["ext_id", "cell_content", "cell_type"]).text.apply(
                     lambda x: "\n".join(x.values)).reset_index()
-            df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
+            if self.drop_duplicates:
+                df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
 
         if self.split_btags:
             df["text"] = df["text"].replace(re.compile(r"(\</?b\>)"), r" \1 ")
         df = df.replace(re.compile(r"(xxref|xxanchor)-[\w\d-]*"), "\\1 ")
-        df = df.replace(re.compile(r"(^|[ ])\d+\.\d+(\b|%)"), " xxnum ")
-        df = df.replace(re.compile(r"(^|[ ])\d+(\b|%)"), " xxnum ")
+        if self.remove_num:
+            df = df.replace(re.compile(r"(^|[ ])\d+\.\d+(\b|%)"), " xxnum ")
+            df = df.replace(re.compile(r"(^|[ ])\d+(\b|%)"), " xxnum ")
         df = df.replace(re.compile(r"\bdata set\b"), " dataset ")
         df["label"] = df["cell_type"].apply(lambda x: label_map.get(x, 0))
         df["label"] = pd.Categorical(df["label"])
@@ -193,6 +213,7 @@ def _set_results(self, prefix, preds, true_y):
         r = {}
         r[f"{prefix}_accuracy"] = m["accuracy"]
         r[f"{prefix}_precision"] = m["precision"]
+        r[f"{prefix}_recall"] = m["recall"]
         r[f"{prefix}_cm"] = confusion_matrix(true_y, preds).tolist()
         self.update_results(**r)
 
@@ -214,26 +235,29 @@ def evaluate(self, model, train_df, valid_df, test_df):
             true_y = tdf["label"]
             self._set_results(prefix, preds, true_y)
 
-    def show_results(self, *ds):
+    def show_results(self, *ds, normalize=True):
         if not len(ds):
             ds = ["train", "valid", "test"]
         for prefix in ds:
             print(f"{prefix} dataset")
-            print(f" * accuracy: {self.results[f'{prefix}_accuracy']}")
-            print(f" * precision: {self.results[f'{prefix}_precision']}")
-            self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm']), normalize=True)
+            print(f" * accuracy: {self.results[f'{prefix}_accuracy']:.3f}")
+            print(f" * μ-precision: {self.results[f'{prefix}_precision']:.3f}")
+            print(f" * μ-recall: {self.results[f'{prefix}_recall']:.3f}")
+            self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm']), normalize=normalize)
 
-    def _plot_confusion_matrix(self, cm, normalize):
+    def _plot_confusion_matrix(self, cm, normalize, fmt=None):
         if normalize:
             cm = cm / cm.sum(axis=1)[:, None]
-        target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)"]
+        if fmt is None:
+            fmt = "0.2f" if normalize else "d"
+        target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC"]
         df_cm = pd.DataFrame(cm, index=[i for i in target_names],
                              columns=[i for i in target_names])
         plt.figure(figsize=(10, 10))
         ax = sn.heatmap(df_cm,
                         annot=True,
                         square=True,
-                        fmt="0.2f" if normalize else "d",
+                        fmt=fmt,
                         cmap="YlGnBu",
                         mask=cm == 0,
                         linecolor="black",
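A hedged usage sketch of the options this diff touches, assuming `Experiment` is constructible with the keyword fields shown above and is importable from its module (here called `experiment` for illustration); `model` and the three evidence DataFrames are placeholders, not defined in this commit:

```python
# Illustrative sketch only: `experiment` as a module name, `model`, and the
# train/valid/test DataFrames are assumptions, not part of this diff.
from experiment import Experiment

exp = Experiment(
    merge_type="concat",     # mark_this_paper requires merge_type == "concat"
    this_paper=False,        # ...and this_paper must be False (see the new check)
    mark_this_paper=True,    # prefix fragments with "xxfld 3 this paper"/"other paper"
    remove_num=False,        # keep literal numbers instead of replacing them with xxnum
    drop_duplicates=False,   # keep duplicate evidence rows
)

# train_df/valid_df/test_df are expected to carry the columns used in
# _transform_df (ext_id, cell_content, cell_type, this_paper, text, ...).
exp.evaluate(model, train_df, valid_df, test_df)
exp.show_results("valid", normalize=False)  # raw counts in the confusion matrix
```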