
Commit ea28280

Merge branch 'dev' into fix/save_out_dim_to_checkpoint
2 parents: 1285e5a + 1a13718

File tree: 9 files changed (+373, -63 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -175,3 +175,4 @@ chebai.egg-info
 lightning_logs
 logs
 .isort.cfg
+/.vscode
```

chebai/models/base.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,9 +1,9 @@
 import logging
+from abc import ABC, abstractmethod
 from typing import Any, Dict, Iterable, Optional, Union

 import torch
 from lightning.pytorch.core.module import LightningModule
-from torchmetrics import Metric

 from chebai.preprocessing.structures import XYData

@@ -12,7 +12,7 @@
 _MODEL_REGISTRY = dict()


-class ChebaiBaseNet(LightningModule):
+class ChebaiBaseNet(LightningModule, ABC):
     """
     Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule.

@@ -356,6 +356,7 @@ def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
             logger=True,
         )

+    @abstractmethod
     def forward(self, x: Dict[str, Any]) -> torch.Tensor:
         """
         Defines the forward pass.
@@ -366,7 +367,7 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
         Returns:
             torch.Tensor: The model output.
         """
-        raise NotImplementedError
+        pass

     def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
         """
```

chebai/preprocessing/bin/graph_properties/tokens.txt

Whitespace-only changes.

chebai/preprocessing/datasets/base.py

Lines changed: 17 additions & 13 deletions
```diff
@@ -29,7 +29,8 @@ class XYBaseDataModule(LightningDataModule):

     Args:
         batch_size (int): The batch size for data loading. Default is 1.
-        train_split (float): The ratio of training data to total data and of test data to (validation + test) data. Default is 0.85.
+        test_split (float): The ratio of test data to total data. Default is 0.1.
+        validation_split (float): The ratio of validation data to total data. Default is 0.05.
         reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None.
         prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test".
         data_limit (Optional[int]): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None.
@@ -45,7 +46,8 @@ class XYBaseDataModule(LightningDataModule):
     Attributes:
         READER (DataReader): The data reader class to use.
         reader (DataReader): An instance of the data reader class.
-        train_split (float): The ratio of training data to total data.
+        test_split (float): The ratio of test data to total data.
+        validation_split (float): The ratio of validation data to total data.
         batch_size (int): The batch size for data loading.
         prediction_kind (str): The kind of prediction to be performed.
         data_limit (Optional[int]): The maximum number of data samples to load.
@@ -68,7 +70,8 @@ class XYBaseDataModule(LightningDataModule):
     def __init__(
         self,
         batch_size: int = 1,
-        train_split: float = 0.85,
+        test_split: Optional[float] = 0.1,
+        validation_split: Optional[float] = 0.05,
         reader_kwargs: Optional[dict] = None,
         prediction_kind: str = "test",
         data_limit: Optional[int] = None,
@@ -86,7 +89,9 @@ def __init__(
         if reader_kwargs is None:
             reader_kwargs = dict()
         self.reader = self.READER(**reader_kwargs)
-        self.train_split = train_split
+        self.test_split = test_split
+        self.validation_split = validation_split
+
         self.batch_size = batch_size
         self.prediction_kind = prediction_kind
         self.data_limit = data_limit
@@ -1022,15 +1027,13 @@ def get_test_split(

         labels_list = df["labels"].tolist()

-        test_size = 1 - self.train_split - (1 - self.train_split) ** 2
-
         if len(labels_list[0]) > 1:
             splitter = MultilabelStratifiedShuffleSplit(
-                n_splits=1, test_size=test_size, random_state=seed
+                n_splits=1, test_size=self.test_split, random_state=seed
             )
         else:
             splitter = StratifiedShuffleSplit(
-                n_splits=1, test_size=test_size, random_state=seed
+                n_splits=1, test_size=self.test_split, random_state=seed
             )

         train_indices, test_indices = next(splitter.split(labels_list, labels_list))
@@ -1083,16 +1086,17 @@ def get_train_val_splits_given_test(

             return folds

-        # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
-        test_size = ((1 - self.train_split) ** 2) / self.train_split
-
         if len(labels_list_trainval[0]) > 1:
             splitter = MultilabelStratifiedShuffleSplit(
-                n_splits=1, test_size=test_size, random_state=seed
+                n_splits=1,
+                test_size=self.validation_split / (1 - self.test_split),
+                random_state=seed,
             )
         else:
             splitter = StratifiedShuffleSplit(
-                n_splits=1, test_size=test_size, random_state=seed
+                n_splits=1,
+                test_size=self.validation_split / (1 - self.test_split),
+                random_state=seed,
             )

         train_indices, validation_indices = next(
```
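A quick sanity check of the new split arithmetic (plain Python, using the defaults from this commit): the train/validation splitter only sees the data left over after the test split, so `validation_split` is rescaled by `1 / (1 - test_split)` to remain a fraction of the full dataset.

```python
# Defaults introduced by this commit.
test_split = 0.1         # 10% of the full dataset -> test
validation_split = 0.05  # 5% of the full dataset -> validation

# The second splitter runs on the remaining (1 - test_split) of the data,
# so the fraction handed to it must be rescaled:
val_fraction_of_remainder = validation_split / (1 - test_split)  # ~0.0556

# Check: the rescaled fraction of the remainder equals 5% of the full dataset,
# reproducing the old 85/5/10 split without the train_split algebra.
assert abs(val_fraction_of_remainder * (1 - test_split) - validation_split) < 1e-12
train_fraction = 1 - test_split - validation_split  # 0.85
```

This makes all three fractions explicit instead of deriving the test and validation sizes from a single `train_split` value.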

chebai/result/analyse_sem.py

Lines changed: 66 additions & 40 deletions
```diff
@@ -1,20 +1,21 @@
 import gc
-import sys
 import traceback
 from datetime import datetime
 from typing import List, LiteralString

+import pandas as pd
 from torchmetrics.functional.classification import (
     multilabel_auroc,
     multilabel_average_precision,
     multilabel_f1_score,
 )
-from utils import *

 from chebai.loss.semantic import DisjointLoss
+from chebai.models import Electra
 from chebai.preprocessing.datasets.base import _DynamicDataset
 from chebai.preprocessing.datasets.chebi import ChEBIOver100
 from chebai.preprocessing.datasets.pubchem import PubChemKMeans
+from chebai.result.utils import *

 DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@@ -122,7 +123,7 @@ def load_preds_labels(
 def get_label_names(data_module):
     if os.path.exists(os.path.join(data_module.processed_dir_main, "classes.txt")):
         with open(os.path.join(data_module.processed_dir_main, "classes.txt")) as fin:
-            return [int(line.strip()) for line in fin]
+            return [line.strip() for line in fin]
     print(
         f"Failed to retrieve label names, {os.path.join(data_module.processed_dir_main, 'classes.txt')} not found"
     )
@@ -131,70 +132,97 @@ def get_label_names(data_module):

 def get_chebi_graph(data_module, label_names):
     if os.path.exists(os.path.join(data_module.raw_dir, "chebi.obo")):
-        chebi_graph = data_module.extract_class_hierarchy(
+        chebi_graph = data_module._extract_class_hierarchy(
             os.path.join(data_module.raw_dir, "chebi.obo")
         )
-        return chebi_graph.subgraph(label_names)
+        return chebi_graph.subgraph([int(n) for n in label_names])
     print(
         f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found"
     )
     return None


-def get_disjoint_groups():
-    disjoints_owl_file = os.path.join("data", "chebi-disjoints.owl")
-    with open(disjoints_owl_file, "r") as f:
-        plaintext = f.read()
-        segments = plaintext.split("<")
-        disjoint_pairs = []
-        left = None
-        for seg in segments:
-            if seg.startswith("rdf:Description ") or seg.startswith("owl:Class"):
-                left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
-            elif seg.startswith("owl:disjointWith"):
-                right = int(seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0])
-                disjoint_pairs.append([left, right])
-
-        disjoint_groups = []
-        for seg in plaintext.split("<rdf:Description>"):
-            if "owl;AllDisjointClasses" in seg:
-                classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
-                classes = [int(c.split('"')[0]) for c in classes]
-                disjoint_groups.append(classes)
+def get_disjoint_groups(disjoint_files):
+    if disjoint_files is None:
+        disjoint_files = os.path.join("data", "chebi-disjoints.owl")
+    disjoint_pairs, disjoint_groups = [], []
+    for file in disjoint_files:
+        if file.split(".")[-1] == "csv":
+            disjoint_pairs += pd.read_csv(file, header=None).values.tolist()
+        elif file.split(".")[-1] == "owl":
+            with open(file, "r") as f:
+                plaintext = f.read()
+            segments = plaintext.split("<")
+            disjoint_pairs = []
+            left = None
+            for seg in segments:
+                if seg.startswith("rdf:Description ") or seg.startswith(
+                    "owl:Class"
+                ):
+                    left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
+                elif seg.startswith("owl:disjointWith"):
+                    right = int(
+                        seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]
+                    )
+                    disjoint_pairs.append([left, right])
+
+            disjoint_groups = []
+            for seg in plaintext.split("<rdf:Description>"):
+                if "owl;AllDisjointClasses" in seg:
+                    classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
+                    classes = [int(c.split('"')[0]) for c in classes]
+                    disjoint_groups.append(classes)
+        else:
+            raise NotImplementedError(
+                "Unsupported disjoint file format: " + file.split(".")[-1]
+            )
+
     disjoint_all = disjoint_pairs + disjoint_groups
     # one disjointness is commented out in the owl-file
     # (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work)
-    disjoint_all.remove([22729, 51880])
-    print(f"Found {len(disjoint_all)} disjoint groups")
+    if [22729, 51880] in disjoint_all:
+        disjoint_all.remove([22729, 51880])
+    # print(f"Found {len(disjoint_all)} disjoint groups")
     return disjoint_all


 class PredictionSmoother:
     """Removes implication and disjointness violations from predictions"""

-    def __init__(self, dataset):
-        self.label_names = get_label_names(dataset)
+    def __init__(self, dataset, label_names=None, disjoint_files=None):
+        if label_names:
+            self.label_names = label_names
+        else:
+            self.label_names = get_label_names(dataset)
         self.chebi_graph = get_chebi_graph(dataset, self.label_names)
-        self.disjoint_groups = get_disjoint_groups()
+        self.disjoint_groups = get_disjoint_groups(disjoint_files)

     def __call__(self, preds):
-
         preds_sum_orig = torch.sum(preds)
-        print(f"Preds sum: {preds_sum_orig}")
-        # eliminate implication violations by setting each prediction to maximum of its successors
         for i, label in enumerate(self.label_names):
             succs = [
-                self.label_names.index(p) for p in self.chebi_graph.successors(label)
+                self.label_names.index(str(p))
+                for p in self.chebi_graph.successors(int(label))
             ] + [i]
             if len(succs) > 0:
+                if torch.max(preds[:, succs], dim=1).values > 0.5 and preds[:, i] < 0.5:
+                    print(
+                        f"Correcting prediction for {label} to max of subclasses {list(self.chebi_graph.successors(int(label)))}"
+                    )
+                    print(
+                        f"Original pred: {preds[:, i]}, successors: {preds[:, succs]}"
+                    )
                 preds[:, i] = torch.max(preds[:, succs], dim=1).values
-        print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
+        if torch.sum(preds) != preds_sum_orig:
+            print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
         preds_sum_orig = torch.sum(preds)
         # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
         preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
         for disj_group in self.disjoint_groups:
             disj_group = [
-                self.label_names.index(g) for g in disj_group if g in self.label_names
+                self.label_names.index(str(g))
+                for g in disj_group
+                if g in self.label_names
             ]
             if len(disj_group) > 1:
                 old_preds = preds[:, disj_group]
@@ -211,14 +239,12 @@ def __call__(self, preds):
                 print(
                     f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
                 )
-        print(
-            f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}"
-        )
         preds_sum_orig = torch.sum(preds)
         # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
         for i, label in enumerate(self.label_names):
             predecessors = [i] + [
-                self.label_names.index(p) for p in self.chebi_graph.predecessors(label)
+                self.label_names.index(str(p))
+                for p in self.chebi_graph.predecessors(int(label))
             ]
             lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
             preds[:, i] = lowest_predecessors.values
```
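For reference, a hypothetical use of the refactored interface. The paths, the label count, and the no-argument construction of the data module are illustrative assumptions; `disjoint_files` is now expected to be an iterable of `.owl`/`.csv` paths rather than a single hard-coded file:

```python
import torch

from chebai.preprocessing.datasets.chebi import ChEBIOver100
from chebai.result.analyse_sem import PredictionSmoother

# Paths are illustrative; any mix of .owl and .csv disjointness sources works.
smoother = PredictionSmoother(
    dataset=ChEBIOver100(),  # assumes default construction works in your setup
    disjoint_files=["data/chebi-disjoints.owl", "data/extra-disjoints.csv"],
)

preds = torch.rand(1, 854)  # (n_samples, n_labels); 854 is illustrative
smoother(preds)  # preds is adjusted in place (steps 1-3 in the diff above)
```

Note the direction of the type change: label names are now kept as the raw strings from `classes.txt` and cast to `int` only at the ChEBI-graph boundary, which avoids the `str`/`int` mismatches between label lists and graph nodes that the old code was prone to.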
