Skip to content

Commit 4f1f995

Browse files
authored
Merge branch 'dev' into feature/pyproject.toml
2 parents 1083a22 + 1e2a043 commit 4f1f995

File tree

3 files changed

+103
-68
lines changed

3 files changed

+103
-68
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class XYBaseDataModule(LightningDataModule):
2929
3030
Args:
3131
batch_size (int): The batch size for data loading. Default is 1.
32-
train_split (float): The ratio of training data to total data and of test data to (validation + test) data. Default is 0.85.
32+
test_split (float): The ratio of test data to total data. Default is 0.1.
33+
validation_split (float): The ratio of validation data to total data. Default is 0.05.
3334
reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None.
3435
prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test".
3536
data_limit (Optional[int]): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None.
@@ -45,7 +46,8 @@ class XYBaseDataModule(LightningDataModule):
4546
Attributes:
4647
READER (DataReader): The data reader class to use.
4748
reader (DataReader): An instance of the data reader class.
48-
train_split (float): The ratio of training data to total data.
49+
test_split (float): The ratio of test data to total data.
50+
validation_split (float): The ratio of validation data to total data.
4951
batch_size (int): The batch size for data loading.
5052
prediction_kind (str): The kind of prediction to be performed.
5153
data_limit (Optional[int]): The maximum number of data samples to load.
@@ -68,7 +70,8 @@ class XYBaseDataModule(LightningDataModule):
6870
def __init__(
6971
self,
7072
batch_size: int = 1,
71-
train_split: float = 0.85,
73+
test_split: Optional[float] = 0.1,
74+
validation_split: Optional[float] = 0.05,
7275
reader_kwargs: Optional[dict] = None,
7376
prediction_kind: str = "test",
7477
data_limit: Optional[int] = None,
@@ -86,7 +89,9 @@ def __init__(
8689
if reader_kwargs is None:
8790
reader_kwargs = dict()
8891
self.reader = self.READER(**reader_kwargs)
89-
self.train_split = train_split
92+
self.test_split = test_split
93+
self.validation_split = validation_split
94+
9095
self.batch_size = batch_size
9196
self.prediction_kind = prediction_kind
9297
self.data_limit = data_limit
@@ -1022,15 +1027,13 @@ def get_test_split(
10221027

10231028
labels_list = df["labels"].tolist()
10241029

1025-
test_size = 1 - self.train_split - (1 - self.train_split) ** 2
1026-
10271030
if len(labels_list[0]) > 1:
10281031
splitter = MultilabelStratifiedShuffleSplit(
1029-
n_splits=1, test_size=test_size, random_state=seed
1032+
n_splits=1, test_size=self.test_split, random_state=seed
10301033
)
10311034
else:
10321035
splitter = StratifiedShuffleSplit(
1033-
n_splits=1, test_size=test_size, random_state=seed
1036+
n_splits=1, test_size=self.test_split, random_state=seed
10341037
)
10351038

10361039
train_indices, test_indices = next(splitter.split(labels_list, labels_list))
@@ -1083,16 +1086,17 @@ def get_train_val_splits_given_test(
10831086

10841087
return folds
10851088

1086-
# scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
1087-
test_size = ((1 - self.train_split) ** 2) / self.train_split
1088-
10891089
if len(labels_list_trainval[0]) > 1:
10901090
splitter = MultilabelStratifiedShuffleSplit(
1091-
n_splits=1, test_size=test_size, random_state=seed
1091+
n_splits=1,
1092+
test_size=self.validation_split / (1 - self.test_split),
1093+
random_state=seed,
10921094
)
10931095
else:
10941096
splitter = StratifiedShuffleSplit(
1095-
n_splits=1, test_size=test_size, random_state=seed
1097+
n_splits=1,
1098+
test_size=self.validation_split / (1 - self.test_split),
1099+
random_state=seed,
10961100
)
10971101

10981102
train_indices, validation_indices = next(

chebai/result/analyse_sem.py

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
11
import gc
2-
import os
3-
import sys
42
import traceback
53
from datetime import datetime
6-
from typing import List, LiteralString, Optional, Tuple
4+
from typing import List, LiteralString
75

8-
import torch
9-
import wandb
6+
import pandas as pd
107
from torchmetrics.functional.classification import (
118
multilabel_auroc,
129
multilabel_average_precision,
1310
multilabel_f1_score,
1411
)
15-
from utils import evaluate_model, get_checkpoint_from_wandb, load_results_from_buffer
1612

1713
from chebai.loss.semantic import DisjointLoss
1814
from chebai.models import Electra
1915
from chebai.preprocessing.datasets.base import _DynamicDataset
2016
from chebai.preprocessing.datasets.chebi import ChEBIOver100
21-
22-
# from chebai.preprocessing.datasets.pubchem import PubChemKMeans
17+
from chebai.preprocessing.datasets.pubchem import PubChemKMeans
18+
from chebai.result.utils import *
2319

2420
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2521

@@ -127,7 +123,7 @@ def load_preds_labels(
127123
def get_label_names(data_module):
128124
if os.path.exists(os.path.join(data_module.processed_dir_main, "classes.txt")):
129125
with open(os.path.join(data_module.processed_dir_main, "classes.txt")) as fin:
130-
return [int(line.strip()) for line in fin]
126+
return [line.strip() for line in fin]
131127
print(
132128
f"Failed to retrieve label names, {os.path.join(data_module.processed_dir_main, 'classes.txt')} not found"
133129
)
@@ -136,69 +132,97 @@ def get_label_names(data_module):
136132

137133
def get_chebi_graph(data_module, label_names):
138134
if os.path.exists(os.path.join(data_module.raw_dir, "chebi.obo")):
139-
chebi_graph = data_module.extract_class_hierarchy(
135+
chebi_graph = data_module._extract_class_hierarchy(
140136
os.path.join(data_module.raw_dir, "chebi.obo")
141137
)
142-
return chebi_graph.subgraph(label_names)
138+
return chebi_graph.subgraph([int(n) for n in label_names])
143139
print(
144140
f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found"
145141
)
146142
return None
147143

148144

149-
def get_disjoint_groups():
150-
disjoints_owl_file = os.path.join("data", "chebi-disjoints.owl")
151-
with open(disjoints_owl_file, "r") as f:
152-
plaintext = f.read()
153-
segments = plaintext.split("<")
154-
disjoint_pairs = []
155-
left = None
156-
for seg in segments:
157-
if seg.startswith("rdf:Description ") or seg.startswith("owl:Class"):
158-
left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
159-
elif seg.startswith("owl:disjointWith"):
160-
right = int(seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0])
161-
disjoint_pairs.append([left, right])
162-
163-
disjoint_groups = []
164-
for seg in plaintext.split("<rdf:Description>"):
165-
if "owl;AllDisjointClasses" in seg:
166-
classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
167-
classes = [int(c.split('"')[0]) for c in classes]
168-
disjoint_groups.append(classes)
145+
def get_disjoint_groups(disjoint_files):
146+
if disjoint_files is None:
147+
disjoint_files = os.path.join("data", "chebi-disjoints.owl")
148+
disjoint_pairs, disjoint_groups = [], []
149+
for file in disjoint_files:
150+
if file.split(".")[-1] == "csv":
151+
disjoint_pairs += pd.read_csv(file, header=None).values.tolist()
152+
elif file.split(".")[-1] == "owl":
153+
with open(file, "r") as f:
154+
plaintext = f.read()
155+
segments = plaintext.split("<")
156+
disjoint_pairs = []
157+
left = None
158+
for seg in segments:
159+
if seg.startswith("rdf:Description ") or seg.startswith(
160+
"owl:Class"
161+
):
162+
left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
163+
elif seg.startswith("owl:disjointWith"):
164+
right = int(
165+
seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]
166+
)
167+
disjoint_pairs.append([left, right])
168+
169+
disjoint_groups = []
170+
for seg in plaintext.split("<rdf:Description>"):
171+
if "owl;AllDisjointClasses" in seg:
172+
classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
173+
classes = [int(c.split('"')[0]) for c in classes]
174+
disjoint_groups.append(classes)
175+
else:
176+
raise NotImplementedError(
177+
"Unsupported disjoint file format: " + file.split(".")[-1]
178+
)
179+
169180
disjoint_all = disjoint_pairs + disjoint_groups
170181
# one disjointness is commented out in the owl-file
171182
# (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work)
172-
disjoint_all.remove([22729, 51880])
173-
print(f"Found {len(disjoint_all)} disjoint groups")
183+
if [22729, 51880] in disjoint_all:
184+
disjoint_all.remove([22729, 51880])
185+
# print(f"Found {len(disjoint_all)} disjoint groups")
174186
return disjoint_all
175187

176188

177189
class PredictionSmoother:
178190
"""Removes implication and disjointness violations from predictions"""
179191

180-
def __init__(self, dataset):
181-
self.label_names = get_label_names(dataset)
192+
def __init__(self, dataset, label_names=None, disjoint_files=None):
193+
if label_names:
194+
self.label_names = label_names
195+
else:
196+
self.label_names = get_label_names(dataset)
182197
self.chebi_graph = get_chebi_graph(dataset, self.label_names)
183-
self.disjoint_groups = get_disjoint_groups()
198+
self.disjoint_groups = get_disjoint_groups(disjoint_files)
184199

185200
def __call__(self, preds):
186201
preds_sum_orig = torch.sum(preds)
187-
print(f"Preds sum: {preds_sum_orig}")
188-
# eliminate implication violations by setting each prediction to maximum of its successors
189202
for i, label in enumerate(self.label_names):
190203
succs = [
191-
self.label_names.index(p) for p in self.chebi_graph.successors(label)
204+
self.label_names.index(str(p))
205+
for p in self.chebi_graph.successors(int(label))
192206
] + [i]
193207
if len(succs) > 0:
208+
if torch.max(preds[:, succs], dim=1).values > 0.5 and preds[:, i] < 0.5:
209+
print(
210+
f"Correcting prediction for {label} to max of subclasses {list(self.chebi_graph.successors(int(label)))}"
211+
)
212+
print(
213+
f"Original pred: {preds[:, i]}, successors: {preds[:, succs]}"
214+
)
194215
preds[:, i] = torch.max(preds[:, succs], dim=1).values
195-
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
216+
if torch.sum(preds) != preds_sum_orig:
217+
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
196218
preds_sum_orig = torch.sum(preds)
197219
# step 2: eliminate disjointness violations: for each group of disjoint classes, set all predictions except the max to 0.49 (if they are not already lower)
198220
preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
199221
for disj_group in self.disjoint_groups:
200222
disj_group = [
201-
self.label_names.index(g) for g in disj_group if g in self.label_names
223+
self.label_names.index(str(g))
224+
for g in disj_group
225+
if g in self.label_names
202226
]
203227
if len(disj_group) > 1:
204228
old_preds = preds[:, disj_group]
@@ -215,14 +239,12 @@ def __call__(self, preds):
215239
print(
216240
f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
217241
)
218-
print(
219-
f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}"
220-
)
221242
preds_sum_orig = torch.sum(preds)
222243
# step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
223244
for i, label in enumerate(self.label_names):
224245
predecessors = [i] + [
225-
self.label_names.index(p) for p in self.chebi_graph.predecessors(label)
246+
self.label_names.index(str(p))
247+
for p in self.chebi_graph.predecessors(int(label))
226248
]
227249
lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
228250
preds[:, i] = lowest_predecessors.values

chebai/result/_generate_classes_props_json.py renamed to chebai/result/generate_class_properties.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
class ClassesPropertiesGenerator:
1717
"""
18-
Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value)
18+
Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value) and counts the number of
19+
true positives (TP), false positives (FP), true negatives (TN), and false negatives (FN)
1920
for each class in a multi-label classification problem using a PyTorch Lightning model.
2021
"""
2122

@@ -35,23 +36,25 @@ def load_class_labels(path: Path) -> list[str]:
3536
return [line.strip() for line in f if line.strip()]
3637

3738
@staticmethod
38-
def compute_tpv_npv(
39+
def compute_classwise_scores(
3940
y_true: list[torch.Tensor],
4041
y_pred: list[torch.Tensor],
42+
raw_preds: torch.Tensor,
4143
class_names: list[str],
4244
) -> dict[str, dict[str, float]]:
4345
"""
44-
Compute TPV (precision) and NPV for each class in a multi-label setting.
46+
Compute PPV (precision, TP/(TP+FP)), NPV (TN/(TN+FN)) and the number of TNs, FPs, FNs and TPs for each class
47+
in a multi-label setting.
4548
4649
Args:
4750
y_true: List of binary ground-truth label tensors, one tensor per sample.
4851
y_pred: List of binary prediction tensors, one tensor per sample.
4952
class_names: Ordered list of class names corresponding to class indices.
5053
5154
Returns:
52-
Dictionary mapping each class name to its TPV and NPV metrics:
55+
Dictionary mapping each class name to its PPV and NPV metrics and its TN/FP/FN/TP counts:
5356
{
54-
"class_name": {"PPV": float, "NPV": float},
57+
"class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int},
5558
...
5659
}
5760
"""
@@ -67,13 +70,17 @@ def compute_tpv_npv(
6770
tn, fp, fn, tp = cm[idx].ravel()
6871
tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
6972
npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
73+
# positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
74+
# negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]
7075
results[cls_name] = {
7176
"PPV": round(tpv, 4),
7277
"NPV": round(npv, 4),
7378
"TN": int(tn),
7479
"FP": int(fp),
7580
"FN": int(fn),
7681
"TP": int(tp),
82+
# "positive_preds": positive_raw,
83+
# "negative_preds": negative_raw,
7784
}
7885
return results
7986

@@ -125,6 +132,7 @@ def generate_props(
125132
print("Running inference on validation data...")
126133

127134
y_true, y_pred = [], []
135+
raw_preds = []
128136
for batch_idx, batch in enumerate(val_loader):
129137
data = model._process_batch( # pylint: disable=W0212
130138
batch, batch_idx=batch_idx
@@ -135,20 +143,21 @@ def generate_props(
135143
preds = torch.sigmoid(logits) > 0.5
136144
y_pred.extend(preds)
137145
y_true.extend(labels)
138-
139-
print("Computing TPV and NPV metrics...")
146+
raw_preds.extend(torch.sigmoid(logits))
147+
raw_preds = torch.stack(raw_preds)
148+
print("Computing metrics...")
140149
classes_file = Path(data_module.processed_dir_main) / "classes.txt"
141150
if output_path is None:
142151
output_file = Path(data_module.processed_dir_main) / "classes.json"
143152
else:
144153
output_file = Path(output_path)
145154

146155
class_names = self.load_class_labels(classes_file)
147-
metrics = self.compute_tpv_npv(y_true, y_pred, class_names)
156+
metrics = self.compute_classwise_scores(y_true, y_pred, raw_preds, class_names)
148157

149158
with output_file.open("w") as f:
150159
json.dump(metrics, f, indent=2)
151-
print(f"Saved TPV/NPV metrics to {output_file}")
160+
print(f"Saved metrics to {output_file}")
152161

153162

154163
class Main:
@@ -164,7 +173,7 @@ def generate(
164173
output_path: str | None = None,
165174
) -> None:
166175
"""
167-
CLI command to generate TPV/NPV JSON.
176+
CLI command to generate JSON with metrics on validation set.
168177
169178
Args:
170179
model_ckpt_path: Path to the PyTorch Lightning checkpoint file.

0 commit comments

Comments
 (0)