Commit 1e128f1

Show extended info in confusion matrices
1 parent 15cfa0b commit 1e128f1

File tree

1 file changed: +44 −7 lines


sota_extractor2/models/structure/experiment.py

Lines changed: 44 additions & 7 deletions
@@ -20,15 +20,40 @@ class Labels(Enum):
     EMPTY=5


+class LabelsExt(Enum):
+    OTHER=0
+    PARAMS=6
+    TASK=7
+    DATASET=1
+    SUBDATASET=8
+    PAPER_MODEL=2
+    BEST_MODEL=9
+    ENSEMBLE_MODEL=10
+    COMPETING_MODEL=3
+    METRIC=4
+    EMPTY=5
+
+
 label_map = {
     "dataset": Labels.DATASET.value,
     "dataset-sub": Labels.DATASET.value,
     "model-paper": Labels.PAPER_MODEL.value,
     "model-best": Labels.PAPER_MODEL.value,
     "model-ensemble": Labels.PAPER_MODEL.value,
     "model-competing": Labels.COMPETING_MODEL.value,
-    "dataset-metric": Labels.METRIC.value,
-    # "model-params": Labels.PARAMS.value
+    "dataset-metric": Labels.METRIC.value
+}
+
+label_map_ext = {
+    "dataset": LabelsExt.DATASET.value,
+    "dataset-sub": LabelsExt.SUBDATASET.value,
+    "model-paper": LabelsExt.PAPER_MODEL.value,
+    "model-best": LabelsExt.BEST_MODEL.value,
+    "model-ensemble": LabelsExt.ENSEMBLE_MODEL.value,
+    "model-competing": LabelsExt.COMPETING_MODEL.value,
+    "dataset-metric": LabelsExt.METRIC.value,
+    "model-params": LabelsExt.PARAMS.value,
+    "dataset-task": LabelsExt.TASK.value
 }

 # put here to avoid recompiling, used only in _limit_context
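The two maps deliberately reuse the same integer values where classes coincide, so the coarse and extended views stay comparable. A minimal sketch of how they differ on the same cell types, assuming Labels defines OTHER=0, DATASET=1, PAPER_MODEL=2, COMPETING_MODEL=3 and METRIC=4 alongside the EMPTY=5 visible above:

    # Hypothetical probe of the two maps; unmapped cell types fall back to OTHER (0).
    for cell_type in ["model-best", "dataset-sub", "model-params"]:
        coarse = label_map.get(cell_type, 0)
        fine = label_map_ext.get(cell_type, 0)
        print(cell_type, Labels(coarse).name, LabelsExt(fine).name)
    # model-best    PAPER_MODEL  BEST_MODEL
    # dataset-sub   DATASET      SUBDATASET
    # model-params  OTHER        PARAMS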
@@ -63,6 +88,7 @@ class Experiment:
     remove_num: bool = True
     drop_duplicates: bool = True
     mark_this_paper: bool = False
+    distinguish_model_source: bool = True

     results: dict = dataclasses.field(default_factory=dict)

@@ -219,6 +245,8 @@ def _transform_df(self, df):
         df = df.replace(re.compile(r"(^|[ ])\d+(\b|%)"), " xxnum ")
         df = df.replace(re.compile(r"\bdata set\b"), " dataset ")
         df["label"] = df["cell_type"].apply(lambda x: label_map.get(x, 0))
+        if not self.distinguish_model_source:
+            df["label"] = df["label"].apply(lambda x: x if x != Labels.COMPETING_MODEL.value else Labels.PAPER_MODEL.value)
         df["label"] = pd.Categorical(df["label"])
         return df

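The distinguish_model_source flag added above defaults to True, preserving the old six-class behavior; setting it to False folds competing models into the paper-model class before training and evaluation. A minimal sketch of that collapse on a bare Series, assuming the enum values above (COMPETING_MODEL=3, PAPER_MODEL=2, DATASET=1):

    import pandas as pd

    labels = pd.Series([Labels.COMPETING_MODEL.value, Labels.PAPER_MODEL.value, Labels.DATASET.value])
    merged = labels.apply(lambda x: x if x != Labels.COMPETING_MODEL.value else Labels.PAPER_MODEL.value)
    print(merged.tolist())  # [2, 2, 1] -- competing models now count as paper models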
@@ -228,13 +256,15 @@ def transform_df(self, *dfs):
             return transformed[0]
         return transformed

-    def _set_results(self, prefix, preds, true_y):
+    def _set_results(self, prefix, preds, true_y, true_y_ext=None):
         m = metrics(preds, true_y)
         r = {}
         r[f"{prefix}_accuracy"] = m["accuracy"]
         r[f"{prefix}_precision"] = m["precision"]
         r[f"{prefix}_recall"] = m["recall"]
         r[f"{prefix}_cm"] = confusion_matrix(true_y, preds, labels=[x.value for x in Labels]).tolist()
+        if true_y_ext is not None:
+            r[f"{prefix}_cm_full"] = confusion_matrix(true_y_ext, preds, labels=[x.value for x in LabelsExt]).tolist()
         self.update_results(**r)

     def evaluate(self, model, train_df, valid_df, test_df):
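Note the asymmetry in the new _cm_full matrix: true labels live in the extended space (eleven values) while predictions stay in the coarse space (six values). Passing the full LabelsExt value list keeps the matrix square, so the columns for fine-grained classes the model never predicts simply stay zero, and each row shows how one fine-grained true class is predicted in coarse terms. A small sketch with sklearn's confusion_matrix (rows are true labels, columns are predictions):

    from sklearn.metrics import confusion_matrix

    true_y_ext = [LabelsExt.BEST_MODEL.value, LabelsExt.SUBDATASET.value]  # 9, 8
    preds = [Labels.PAPER_MODEL.value, Labels.DATASET.value]               # 2, 1
    cm = confusion_matrix(true_y_ext, preds, labels=[x.value for x in LabelsExt])
    # cm is 11x11; only the six columns matching coarse label values can be non-zero.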
@@ -253,17 +283,19 @@ def evaluate(self, model, train_df, valid_df, test_df):
                 true_y = vote_results["true"]
             else:
                 true_y = tdf["label"]
-            self._set_results(prefix, preds, true_y)
+            true_y_ext = tdf["cell_type"].apply(lambda x: label_map_ext.get(x, 0))
+            self._set_results(prefix, preds, true_y, true_y_ext)

-    def show_results(self, *ds, normalize=True):
+    def show_results(self, *ds, normalize=True, full_cm=True):
         if not len(ds):
             ds = ["train", "valid", "test"]
         for prefix in ds:
             print(f"{prefix} dataset")
             print(f" * accuracy: {self.results[f'{prefix}_accuracy']:.3f}")
             print(f" * μ-precision: {self.results[f'{prefix}_precision']:.3f}")
             print(f" * μ-recall: {self.results[f'{prefix}_recall']:.3f}")
-            self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm']), normalize=normalize)
+            suffix = '_full' if full_cm and f'{prefix}_cm_full' in self.results else ''
+            self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm{suffix}']), normalize=normalize)

     def _plot_confusion_matrix(self, cm, normalize, fmt=None):
         if normalize:
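A hedged usage note, with e standing in for a hypothetical already-evaluated Experiment: full_cm=True prefers the extended matrix but falls back to the coarse one whenever a split's stored results predate this change and lack a _cm_full entry.

    e.show_results("valid")                 # 11-class matrix when available
    e.show_results("valid", full_cm=False)  # always the original 6-class matrix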
@@ -272,7 +304,12 @@ def _plot_confusion_matrix(self, cm, normalize, fmt=None):
             cm = cm / s
         if fmt is None:
             fmt = "0.2f" if normalize else "d"
-        target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC", "EMPTY"]
+
+        if len(cm) == 6:
+            target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC", "EMPTY"]
+        else:
+            target_names = ["OTHER", "params", "task", "DATASET", "subdataset", "MODEL (paper)", "model (best)",
+                            "model (ens.)", "MODEL (comp.)", "METRIC", "EMPTY"]
         df_cm = pd.DataFrame(cm, index=[i for i in target_names],
                              columns=[i for i in target_names])
         plt.figure(figsize=(10, 10))
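One subtlety worth flagging: the eleven names must follow LabelsExt definition order rather than numeric value order, because the matrix axes were laid out with [x.value for x in LabelsExt]. A sanity check, assuming only the enum above:

    assert [x.name for x in LabelsExt] == [
        "OTHER", "PARAMS", "TASK", "DATASET", "SUBDATASET", "PAPER_MODEL",
        "BEST_MODEL", "ENSEMBLE_MODEL", "COMPETING_MODEL", "METRIC", "EMPTY"]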
