@@ -20,15 +20,40 @@ class Labels(Enum):
     EMPTY = 5
 
 
+class LabelsExt(Enum):
+    OTHER = 0
+    PARAMS = 6
+    TASK = 7
+    DATASET = 1
+    SUBDATASET = 8
+    PAPER_MODEL = 2
+    BEST_MODEL = 9
+    ENSEMBLE_MODEL = 10
+    COMPETING_MODEL = 3
+    METRIC = 4
+    EMPTY = 5
+
+
 label_map = {
     "dataset": Labels.DATASET.value,
     "dataset-sub": Labels.DATASET.value,
     "model-paper": Labels.PAPER_MODEL.value,
     "model-best": Labels.PAPER_MODEL.value,
     "model-ensemble": Labels.PAPER_MODEL.value,
     "model-competing": Labels.COMPETING_MODEL.value,
-    "dataset-metric": Labels.METRIC.value,
-    # "model-params": Labels.PARAMS.value
+    "dataset-metric": Labels.METRIC.value
+}
+
+label_map_ext = {
+    "dataset": LabelsExt.DATASET.value,
+    "dataset-sub": LabelsExt.SUBDATASET.value,
+    "model-paper": LabelsExt.PAPER_MODEL.value,
+    "model-best": LabelsExt.BEST_MODEL.value,
+    "model-ensemble": LabelsExt.ENSEMBLE_MODEL.value,
+    "model-competing": LabelsExt.COMPETING_MODEL.value,
+    "dataset-metric": LabelsExt.METRIC.value,
+    "model-params": LabelsExt.PARAMS.value,
+    "dataset-task": LabelsExt.TASK.value
 }
 
 # put here to avoid recompiling, used only in _limit_context
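
Reading the two maps together: LabelsExt strictly refines Labels, reusing the coarse values 0-5 and adding 6-10 for the new classes, so every extended label collapses back onto the coarse label that label_map would have assigned. A minimal sanity sketch of that relationship (the COARSE_OF table below is illustrative, not part of the patch):

# Illustrative collapse table: each fine-grained LabelsExt member and the
# coarse Labels member it refines.  PARAMS and TASK fold into OTHER because
# "model-params" and "dataset-task" are absent from label_map.
COARSE_OF = {
    LabelsExt.OTHER: Labels.OTHER,
    LabelsExt.PARAMS: Labels.OTHER,
    LabelsExt.TASK: Labels.OTHER,
    LabelsExt.DATASET: Labels.DATASET,
    LabelsExt.SUBDATASET: Labels.DATASET,
    LabelsExt.PAPER_MODEL: Labels.PAPER_MODEL,
    LabelsExt.BEST_MODEL: Labels.PAPER_MODEL,
    LabelsExt.ENSEMBLE_MODEL: Labels.PAPER_MODEL,
    LabelsExt.COMPETING_MODEL: Labels.COMPETING_MODEL,
    LabelsExt.METRIC: Labels.METRIC,
    LabelsExt.EMPTY: Labels.EMPTY,
}

# Both maps agree for every cell type: collapsing the fine label reproduces
# the coarse one (0, i.e. OTHER, when the cell type is missing from label_map).
for cell_type, fine in label_map_ext.items():
    assert COARSE_OF[LabelsExt(fine)].value == label_map.get(cell_type, 0)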
@@ -63,6 +88,7 @@ class Experiment:
     remove_num: bool = True
     drop_duplicates: bool = True
     mark_this_paper: bool = False
+    distinguish_model_source: bool = True
 
     results: dict = dataclasses.field(default_factory=dict)
 
@@ -219,6 +245,8 @@ def _transform_df(self, df):
         df = df.replace(re.compile(r"(^|[ ])\d+(\b|%)"), " xxnum ")
         df = df.replace(re.compile(r"\bdata set\b"), " dataset ")
         df["label"] = df["cell_type"].apply(lambda x: label_map.get(x, 0))
+        if not self.distinguish_model_source:
+            df["label"] = df["label"].apply(lambda x: x if x != Labels.COMPETING_MODEL.value else Labels.PAPER_MODEL.value)
         df["label"] = pd.Categorical(df["label"])
         return df
 
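
For context, a standalone illustration of what the new flag changes in _transform_df: with distinguish_model_source turned off, competing models are relabelled as paper models before the label column is made categorical (a minimal sketch over a hypothetical three-cell frame):

import pandas as pd

# Hypothetical miniature of the cells dataframe.
df = pd.DataFrame({"cell_type": ["model-competing", "model-best", "dataset-metric"]})
df["label"] = df["cell_type"].apply(lambda x: label_map.get(x, 0))  # [3, 2, 4]

# The collapse performed when distinguish_model_source is False:
df["label"] = df["label"].apply(
    lambda x: x if x != Labels.COMPETING_MODEL.value else Labels.PAPER_MODEL.value)

assert df["label"].tolist() == [Labels.PAPER_MODEL.value,  # was COMPETING_MODEL
                                Labels.PAPER_MODEL.value,
                                Labels.METRIC.value]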
@@ -228,13 +256,15 @@ def transform_df(self, *dfs):
             return transformed[0]
         return transformed
 
-    def _set_results(self, prefix, preds, true_y):
+    def _set_results(self, prefix, preds, true_y, true_y_ext=None):
         m = metrics(preds, true_y)
         r = {}
         r[f"{prefix}_accuracy"] = m["accuracy"]
         r[f"{prefix}_precision"] = m["precision"]
         r[f"{prefix}_recall"] = m["recall"]
         r[f"{prefix}_cm"] = confusion_matrix(true_y, preds, labels=[x.value for x in Labels]).tolist()
+        if true_y_ext is not None:
+            r[f"{prefix}_cm_full"] = confusion_matrix(true_y_ext, preds, labels=[x.value for x in LabelsExt]).tolist()
         self.update_results(**r)
 
     def evaluate(self, model, train_df, valid_df, test_df):
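
One subtlety worth spelling out: _cm_full mixes taxonomies. Its rows come from the fine-grained true labels, while the model's predictions stay in the coarse 0-5 value range, so only the six coarse columns of the 11x11 matrix can be populated. A self-contained sketch with hypothetical data:

from sklearn.metrics import confusion_matrix

# Hypothetical ground truth (fine-grained) vs. predictions (coarse only).
true_y_ext = [LabelsExt.SUBDATASET.value, LabelsExt.BEST_MODEL.value, LabelsExt.METRIC.value]
preds = [Labels.DATASET.value, Labels.PAPER_MODEL.value, Labels.METRIC.value]

cm = confusion_matrix(true_y_ext, preds, labels=[x.value for x in LabelsExt])
order = [x.value for x in LabelsExt]  # row/column order of the matrix
assert cm.shape == (11, 11)
# e.g. the SUBDATASET row holds its single count in the (coarse) DATASET column:
assert cm[order.index(LabelsExt.SUBDATASET.value), order.index(Labels.DATASET.value)] == 1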
@@ -253,17 +283,19 @@ def evaluate(self, model, train_df, valid_df, test_df):
                 true_y = vote_results["true"]
             else:
                 true_y = tdf["label"]
-            self._set_results(prefix, preds, true_y)
+            true_y_ext = tdf["cell_type"].apply(lambda x: label_map_ext.get(x, 0))
+            self._set_results(prefix, preds, true_y, true_y_ext)
 
-    def show_results(self, *ds, normalize=True):
+    def show_results(self, *ds, normalize=True, full_cm=True):
         if not len(ds):
             ds = ["train", "valid", "test"]
         for prefix in ds:
             print(f"{prefix} dataset")
             print(f" * accuracy: {self.results[f'{prefix}_accuracy']:.3f}")
             print(f" * μ-precision: {self.results[f'{prefix}_precision']:.3f}")
             print(f" * μ-recall: {self.results[f'{prefix}_recall']:.3f}")
-            self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm']), normalize=normalize)
+            suffix = '_full' if full_cm and f'{prefix}_cm_full' in self.results else ''
+            self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm{suffix}']), normalize=normalize)
 
     def _plot_confusion_matrix(self, cm, normalize, fmt=None):
         if normalize:
@@ -272,7 +304,12 @@ def _plot_confusion_matrix(self, cm, normalize, fmt=None):
             cm = cm / s
         if fmt is None:
             fmt = "0.2f" if normalize else "d"
-        target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC", "EMPTY"]
+
+        if len(cm) == 6:
+            target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC", "EMPTY"]
+        else:
+            target_names = ["OTHER", "params", "task", "DATASET", "subdataset", "MODEL (paper)", "model (best)",
+                            "model (ens.)", "MODEL (comp.)", "METRIC", "EMPTY"]
         df_cm = pd.DataFrame(cm, index=[i for i in target_names],
                              columns=[i for i in target_names])
         plt.figure(figsize=(10, 10))
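
Finally, the len(cm) dispatch above couples the hard-coded target_names lists to enum declaration order: the matrices were built with labels=[x.value for x in Labels] or LabelsExt, so axis position i is the i-th member as declared. A small consistency check, assuming the enums from the first hunk are importable:

# Axis i of the full matrix is the i-th LabelsExt member in declaration order,
# which must line up with the 11-entry target_names list above.
assert [x.name for x in LabelsExt] == [
    "OTHER", "PARAMS", "TASK", "DATASET", "SUBDATASET", "PAPER_MODEL",
    "BEST_MODEL", "ENSEMBLE_MODEL", "COMPETING_MODEL", "METRIC", "EMPTY",
]
assert len(list(Labels)) == 6 and len(list(LabelsExt)) == 11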