Skip to content

Commit 8efa2c1

Browse files
author
sfluegel
committed
Add features for semantic evaluation: selection of single metrics, postprocessing that removes violations, and verbose output.
1 parent e087dc7 commit 8efa2c1

File tree

1 file changed

+196
-15
lines changed

1 file changed

+196
-15
lines changed

chebai/result/analyse_sem.py

Lines changed: 196 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,118 @@ def load_preds_labels_from_nonwandb(
139139
return preds, labels
140140

141141

142+
def get_label_names(data_module):
    """Return the class IDs listed in the data module's raw ``classes.txt``.

    Reads one integer ID per line. Returns ``None`` if the file does not
    exist in ``data_module.raw_dir``.
    """
    classes_file = os.path.join(data_module.raw_dir, "classes.txt")
    if not os.path.exists(classes_file):
        return None
    with open(classes_file) as fin:
        return [int(line.strip()) for line in fin]
147+
148+
149+
def get_chebi_graph(data_module, label_names):
    """Build the ChEBI class hierarchy restricted to ``label_names``.

    Extracts the full hierarchy from the raw ``chebi.obo`` file and returns
    the subgraph induced by the given label IDs. Returns ``None`` if the obo
    file does not exist in ``data_module.raw_dir``.
    """
    obo_path = os.path.join(data_module.raw_dir, "chebi.obo")
    if not os.path.exists(obo_path):
        return None
    full_hierarchy = data_module.extract_class_hierarchy(obo_path)
    return full_hierarchy.subgraph(label_names)
156+
157+
158+
def get_disjoint_groups(disjoints_owl_file=None):
    """Parse CHEBI disjointness axioms from an OWL file.

    Collects pairwise ``owl:disjointWith`` axioms and ``owl:AllDisjointClasses``
    groups by naive tag-splitting (not a real XML parse).

    Args:
        disjoints_owl_file: Path to the OWL file. Defaults to
            ``data/chebi-disjoints.owl`` (the original hard-coded location).

    Returns:
        A list of lists of CHEBI class IDs; each inner list is a group of
        mutually disjoint classes (pairs and larger groups mixed together).
    """
    if disjoints_owl_file is None:
        disjoints_owl_file = os.path.join("data", "chebi-disjoints.owl")
    with open(disjoints_owl_file, "r") as f:
        plaintext = f.read()

    # Pairwise axioms: a class/description tag sets the left-hand side, each
    # following disjointWith tag closes one pair with it.
    segments = plaintext.split("<")
    disjoint_pairs = []
    left = None
    for seg in segments:
        if seg.startswith("rdf:Description ") or seg.startswith("owl:Class"):
            left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
        elif seg.startswith("owl:disjointWith"):
            right = int(seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0])
            disjoint_pairs.append([left, right])

    # AllDisjointClasses groups: collect every CHEBI id mentioned in the
    # corresponding rdf:Description section.
    disjoint_groups = []
    for seg in plaintext.split("<rdf:Description>"):
        if "owl;AllDisjointClasses" in seg:
            classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
            classes = [int(c.split('"')[0]) for c in classes]
            disjoint_groups.append(classes)
    disjoint_all = disjoint_pairs + disjoint_groups
    # One disjointness axiom is commented out in the shipped owl-file; drop it
    # if present. (The correct way would be to parse the owl file and honour
    # the comment markers, but for this case, this suffices.) Guarded so the
    # function also works on files that do not contain this pair.
    if [22729, 51880] in disjoint_all:
        disjoint_all.remove([22729, 51880])
    print(f"Found {len(disjoint_all)} disjoint groups")
    return disjoint_all
184+
185+
186+
def smooth_preds(preds, label_names, chebi_graph, disjoint_groups):
    """Remove logical violations from predictions in three in-place passes.

    Step 1: implication consistency — each class prediction becomes the max
    over itself and its graph successors (presumably its subclasses — the
    edge direction is taken from how this module builds the hierarchy).
    Step 2: disjointness — within each disjoint group, every class except the
    per-sample argmax is capped at 0.49 (i.e. just below a 0.5 decision
    threshold), unless it is already lower.
    Step 3: implications again (step 2 can reintroduce violations) — each
    class prediction becomes the min over itself and its graph predecessors.

    Args:
        preds: Tensor of shape (n_samples, n_classes); modified in place.
        label_names: Class IDs aligned with the columns of ``preds``.
        chebi_graph: Hierarchy graph over ``label_names``.
        disjoint_groups: Lists of mutually disjoint class IDs.

    Returns:
        The (mutated) ``preds`` tensor.
    """
    preds_sum_orig = torch.sum(preds)
    print(f"Preds sum: {preds_sum_orig}")
    # eliminate implication violations by setting each prediction to maximum of its successors
    for i, label in enumerate(label_names):
        # succs always contains i itself, so the column list is never empty
        succs = [label_names.index(p) for p in chebi_graph.successors(label)] + [i]
        preds[:, i] = torch.max(preds[:, succs], dim=1).values
    print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
    preds_sum_orig = torch.sum(preds)
    # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
    preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
    for disj_group in disjoint_groups:
        disj_group = [label_names.index(g) for g in disj_group if g in label_names]
        if len(disj_group) > 1:
            # advanced indexing copies, so this is a snapshot for the change count
            old_preds = preds[:, disj_group]
            disj_max = torch.max(preds[:, disj_group], dim=1)
            # only columns inside the group can change — iterate those directly
            # instead of scanning every class column per sample
            for i in range(preds.size(0)):
                winner = disj_group[disj_max.indices[i]]
                for col in disj_group:
                    if col != winner:
                        preds[i, col] = preds_bounded[i, col]
            samples_changed = 0
            for i, row in enumerate(preds[:, disj_group]):
                if any(r != o for r, o in zip(row, old_preds[i])):
                    samples_changed += 1
            if samples_changed != 0:
                print(
                    f"disjointness group {[label_names[d] for d in disj_group]} changed {samples_changed} samples"
                )
    print(
        f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}"
    )
    preds_sum_orig = torch.sum(preds)
    # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
    for i, label in enumerate(label_names):
        predecessors = [i] + [
            label_names.index(p) for p in chebi_graph.predecessors(label)
        ]
        lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
        preds[:, i] = lowest_predecessors.values
        for idx_idx, idx in enumerate(lowest_predecessors.indices):
            # index 0 is the class itself -> only report actual overwrites
            if idx > 0:
                print(
                    f"class {label}: changed prediction of sample {idx_idx} to value of class "
                    f"{label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})"
                )
        if torch.sum(preds) != preds_sum_orig:
            print(
                f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}"
            )
            preds_sum_orig = torch.sum(preds)
    return preds
238+
239+
142240
def analyse_run(
143241
preds,
144242
labels,
145243
df_hyperparams, # parameters that are the independent of the semantic loss function used
146244
labeled_data_cls=ChEBIOver100, # use labels from this dataset for violations
147245
chebi_version=231,
148246
results_path=os.path.join("_semantic", "eval_results.csv"),
247+
violation_metrics: [str | list[callable]] = "all",
248+
verbose_violation_output=False,
149249
):
150250
"""Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided),
151251
saves results to csv"""
252+
if violation_metrics == "all":
253+
violation_metrics = [product, lukasiewicz, weak, strict, binary]
152254
data_module_labeled = labeled_data_cls(chebi_version=chebi_version)
153255
n_labels = preds.size(1)
154256
print(f"Found {preds.shape[0]} predictions ({n_labels} classes)")
@@ -173,7 +275,7 @@ def analyse_run(
173275
del preds_exp
174276
gc.collect()
175277

176-
for i, metric in enumerate([product, lukasiewicz, weak, strict, binary]):
278+
for i, metric in enumerate(violation_metrics):
177279
if filter_type == "impl":
178280
df_new.append(df_hyperparams.copy())
179281
df_new[-1]["metric"] = metric.__name__
@@ -188,6 +290,27 @@ def analyse_run(
188290
m["fns"] = apply_metric(
189291
metric, l_preds, 1 - r_preds if filter_type == "impl" else r_preds
190292
)
293+
if verbose_violation_output:
294+
label_names = get_label_names(data_module_labeled)
295+
print(f"Found {torch.sum(m['fns'])} {filter_type}-violations")
296+
# for k, fn_cls in enumerate(m['fns']):
297+
# if fn_cls > 0:
298+
# print(f"\tThereof, {fn_cls.item()} belong to class {label_names[k]}")
299+
if torch.sum(m["fns"]) != 0:
300+
fns = metric(
301+
l_preds, 1 - r_preds if filter_type == "impl" else r_preds
302+
)
303+
print(fns.shape)
304+
for k, row in enumerate(fns):
305+
if torch.sum(row) != 0:
306+
print(f"{torch.sum(row)} violations for entity {k}")
307+
for j, violation in enumerate(row):
308+
if violation > 0:
309+
print(
310+
f"\tviolated ({label_names[dl_filter_l[j]]} -> {preds[k, dl_filter_l[j]]:.3f}"
311+
f", {label_names[dl_filter_r[j]]} -> {preds[k, dl_filter_r[j]]:.3f})"
312+
)
313+
191314
m_cls = {}
192315
for key, value in m.items():
193316
m_cls[key] = _sort_results_by_label(
@@ -259,14 +382,23 @@ def run_all(
259382
skip_analyse=False,
260383
skip_preds=False,
261384
nonwandb_runs=None,
385+
violation_metrics="all",
386+
remove_violations=False,
262387
):
263388
# evaluate a list of runs on Hazardous and ChEBIOver100 datasets
264389
if datasets is None:
265390
datasets = [(Hazardous, "all"), (ChEBIOver100, "test")]
266391
timestamp = datetime.now().strftime("%y%m%d-%H%M")
267392
results_path = os.path.join(
268-
"_semloss_eval", f"semloss_results_pc-dis-200k_{timestamp}.csv"
393+
"_semloss_eval",
394+
f"semloss_results_pc-dis-200k_{timestamp}{'_violations_removed' if remove_violations else ''}.csv",
395+
)
396+
label_names = get_label_names(ChEBIOver100(chebi_version=chebi_version))
397+
chebi_graph = get_chebi_graph(
398+
ChEBIOver100(chebi_version=chebi_version), label_names
269399
)
400+
disjoint_groups = get_disjoint_groups()
401+
270402
api = wandb.Api()
271403
for run_id in run_ids:
272404
try:
@@ -280,17 +412,50 @@ def run_all(
280412
"data_module": test_on.__name__,
281413
"chebi_version": chebi_version,
282414
}
283-
if not skip_preds:
284-
preds, labels = load_preds_labels_from_wandb(
285-
run, epoch, chebi_version, test_on, kind
415+
buffer_dir_smoothed = os.path.join(
416+
"results_buffer",
417+
"smoothed3step",
418+
f"{run.name}_ep{epoch}",
419+
f"{test_on.__name__}_{kind}",
420+
)
421+
if remove_violations and os.path.exists(
422+
os.path.join(buffer_dir_smoothed, "preds000.pt")
423+
):
424+
preds = torch.load(
425+
os.path.join(buffer_dir_smoothed, "preds000.pt"), DEVICE
286426
)
427+
labels = None
287428
else:
288-
buffer_dir = os.path.join(
289-
"results_buffer",
290-
f"{run.name}_ep{epoch}",
291-
f"{test_on.__name__}_{kind}",
292-
)
293-
preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)
429+
if not skip_preds:
430+
preds, labels = load_preds_labels_from_wandb(
431+
run, epoch, chebi_version, test_on, kind
432+
)
433+
else:
434+
buffer_dir = os.path.join(
435+
"results_buffer",
436+
f"{run.name}_ep{epoch}",
437+
f"{test_on.__name__}_{kind}",
438+
)
439+
preds, labels = load_results_from_buffer(
440+
buffer_dir, device=DEVICE
441+
)
442+
assert (
443+
preds is not None
444+
), f"Did not find predictions in dir {buffer_dir}"
445+
if remove_violations:
446+
preds = smooth_preds(
447+
preds, label_names, chebi_graph, disjoint_groups
448+
)
449+
buffer_dir_smoothed = os.path.join(
450+
"results_buffer",
451+
"smoothed3step",
452+
f"{run.name}_ep{epoch}",
453+
f"{test_on.__name__}_{kind}",
454+
)
455+
os.makedirs(buffer_dir_smoothed, exist_ok=True)
456+
torch.save(
457+
preds, os.path.join(buffer_dir_smoothed, "preds000.pt")
458+
)
294459
if not skip_analyse:
295460
print(
296461
f"Calculating metrics for run {run.name} on {test_on.__name__} ({kind})"
@@ -301,6 +466,8 @@ def run_all(
301466
df_hyperparams=df,
302467
chebi_version=chebi_version,
303468
results_path=results_path,
469+
violation_metrics=violation_metrics,
470+
verbose_violation_output=True,
304471
)
305472
except Exception as e:
306473
print(f"Failed for run {run_id}: {e}")
@@ -330,6 +497,13 @@ def run_all(
330497
preds, labels = load_results_from_buffer(
331498
buffer_dir, device=DEVICE
332499
)
500+
assert (
501+
preds is not None
502+
), f"Did not find predictions in dir {buffer_dir}"
503+
if remove_violations:
504+
preds = smooth_preds(
505+
preds, label_names, chebi_graph, disjoint_groups
506+
)
333507
if not skip_analyse:
334508
print(
335509
f"Calculating metrics for run {run_name} on {test_on.__name__} ({kind})"
@@ -340,16 +514,15 @@ def run_all(
340514
df_hyperparams=df,
341515
chebi_version=chebi_version,
342516
results_path=results_path,
517+
violation_metrics=violation_metrics,
343518
)
344519
except Exception as e:
345520
print(f"Failed for run {run_name}: {e}")
346521
print(traceback.format_exc())
347522

348523

349524
def run_semloss_eval(mode="eval"):
350-
non_wandb_runs = (
351-
[]
352-
) # ("chebi100_semprodk2_weighted_v231_pc_200k_dis_24042-2000", 195)]
525+
non_wandb_runs = []
353526
if mode == "preds":
354527
api = wandb.Api()
355528
runs = api.runs("chebai/chebai", filters={"tags": "eval_semloss_paper"})
@@ -375,8 +548,16 @@ def run_semloss_eval(mode="eval"):
375548
"tk15yznc",
376549
]
377550
baseline = ["i4wtz1k4", "zd020wkv", "rc1q3t49"]
551+
k2 = ["ng3usn0p", "rp0wwzjv", "8fma1q7r"]
378552
ids = baseline
379-
run_all(ids, skip_preds=True, nonwandb_runs=non_wandb_runs)
553+
run_all(
554+
ids,
555+
skip_preds=True,
556+
nonwandb_runs=non_wandb_runs,
557+
datasets=[(ChEBIOver100, "test")],
558+
violation_metrics=[binary],
559+
remove_violations=True,
560+
)
380561

381562

382563
if __name__ == "__main__":

0 commit comments

Comments
 (0)