
Commit 5cbc5a0

Author: sfluegel
Parents: 660ac34 + 3ca5707

Merge branch 'refs/heads/dev' into feature/improved-hyperparameter-tracking

# Conflicts:
#	configs/data/chebi100.yml


61 files changed: +1455, −66 lines

chebai/callbacks/epoch_metrics.py

Lines changed: 74 additions & 0 deletions
@@ -47,3 +47,77 @@ def compute(self):
         # if (precision and recall are 0) or (precision is nan), set f1 to 0
         classwise_f1 = classwise_f1.nan_to_num()
         return torch.mean(classwise_f1)
+
+
+class BalancedAccuracy(torchmetrics.Metric):
+    """Balanced Accuracy = (TPR + TNR) / 2 = (TP / (TP + FN) + TN / (TN + FP)) / 2
+
+    This metric computes the balanced accuracy, i.e. the average of the true positive rate (TPR)
+    and the true negative rate (TNR). It is useful for imbalanced datasets where the classes are
+    not represented equally.
+    """
+
+    def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+
+        self.add_state(
+            "true_positives",
+            default=torch.zeros(num_labels, dtype=torch.int),
+            dist_reduce_fx="sum",
+        )
+        self.add_state(
+            "false_positives",
+            default=torch.zeros(num_labels, dtype=torch.int),
+            dist_reduce_fx="sum",
+        )
+        self.add_state(
+            "true_negatives",
+            default=torch.zeros(num_labels, dtype=torch.int),
+            dist_reduce_fx="sum",
+        )
+        self.add_state(
+            "false_negatives",
+            default=torch.zeros(num_labels, dtype=torch.int),
+            dist_reduce_fx="sum",
+        )
+
+        self.threshold = threshold
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor):
+        """Update the TPs, TNs, FPs and FNs"""
+        # size: batch_size x num_of_classes;
+        # summing over the batch dimension (dim=0) gives the counts per class
+        tps = torch.sum(
+            torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
+        )
+        fps = torch.sum(
+            torch.logical_and(preds > self.threshold, ~labels.to(torch.bool)), dim=0
+        )
+        tns = torch.sum(
+            torch.logical_and(preds <= self.threshold, ~labels.to(torch.bool)), dim=0
+        )
+        fns = torch.sum(
+            torch.logical_and(preds <= self.threshold, labels.to(torch.bool)), dim=0
+        )
+
+        # size: num_of_classes
+        self.true_positives += tps
+        self.false_positives += fps
+        self.true_negatives += tns
+        self.false_negatives += fns
+
+    def compute(self):
+        """Compute the balanced accuracy, averaged over all classes"""
+        tpr = self.true_positives / (self.true_positives + self.false_negatives)
+        tnr = self.true_negatives / (self.true_negatives + self.false_positives)
+        # convert NaN values (classes with no positive / no negative samples) to 0
+        tpr = tpr.nan_to_num()
+        tnr = tnr.nan_to_num()
+
+        balanced_acc = (tpr + tnr) / 2
+        return torch.mean(balanced_acc)
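For reference, a minimal standalone use of the new metric might look like this (the import path follows the file above; the input values are illustrative):

import torch
from chebai.callbacks.epoch_metrics import BalancedAccuracy

metric = BalancedAccuracy(num_labels=3)
preds = torch.tensor([[0.9, 0.2, 0.7], [0.1, 0.8, 0.4]])  # predicted probabilities
labels = torch.tensor([[1, 0, 1], [0, 1, 1]])
metric.update(preds, labels)
# classes 0 and 1 score 1.0; class 2 has TPR 0.5 and no negative samples (TNR -> 0),
# so the mean balanced accuracy is (1.0 + 1.0 + 0.25) / 3 = 0.75
print(metric.compute())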

chebai/cli.py

Lines changed: 5 additions & 2 deletions
@@ -11,14 +11,17 @@ def __init__(self, *args, **kwargs):
 
     def add_arguments_to_parser(self, parser: LightningArgumentParser):
         for kind in ("train", "val", "test"):
-            for average in ("micro", "macro"):
+            for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
                 parser.link_arguments(
                     "model.init_args.out_dim",
-                    f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
+                    f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
                 )
         parser.link_arguments(
             "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
         )
+        parser.link_arguments(
+            "data", "model.init_args.criterion.init_args.data_extractor"
+        )
 
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
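The renamed keys suggest that each {kind}_metrics module holds a metrics dict keyed by micro-f1, macro-f1 and balanced-accuracy, so that num_labels can be linked into every entry. A hypothetical collection with matching keys (torchmetrics' built-in F1 is used purely for illustration; the repository's own F1 classes may be what the configs actually instantiate):

import torchmetrics
from torchmetrics.classification import MultilabelF1Score

from chebai.callbacks.epoch_metrics import BalancedAccuracy

num_labels = 1000  # placeholder; the CLI links this value from model.init_args.out_dim

metrics = torchmetrics.MetricCollection(
    {
        "micro-f1": MultilabelF1Score(num_labels=num_labels, average="micro"),
        "macro-f1": MultilabelF1Score(num_labels=num_labels, average="macro"),
        "balanced-accuracy": BalancedAccuracy(num_labels=num_labels),
    }
)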

chebai/loss/bce_weighted.py

Lines changed: 13 additions & 5 deletions
@@ -1,5 +1,6 @@
 import torch
-from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
+from chebai.preprocessing.datasets.base import XYBaseDataModule
+from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
 import pandas as pd
 import os
 import pickle
@@ -10,9 +11,16 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
     https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf)
     """
 
-    def __init__(self, beta: float = None, data_extractor: _ChEBIDataExtractor = None):
+    def __init__(
+        self,
+        beta: float = None,
+        data_extractor: XYBaseDataModule = None,
+    ):
         self.beta = beta
+        if isinstance(data_extractor, LabeledUnlabeledMixed):
+            data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
+
         super().__init__()
 
     def set_pos_weight(self, input):
@@ -27,16 +35,16 @@ def set_pos_weight(self, input):
         ):
             complete_data = pd.concat(
                 [
-                    pickle.load(
+                    pd.read_pickle(
                         open(
                             os.path.join(
                                 self.data_extractor.raw_dir,
-                                self.data_extractor.raw_file_names_dict[set],
+                                raw_file_name,
                             ),
                             "rb",
                         )
                     )
-                    for set in ["train", "validation", "test"]
+                    for raw_file_name in self.data_extractor.raw_file_names
                 ]
             )
             value_counts = []
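For context, the paper linked above (Cui et al., CVPR 2019) weights classes by the inverse of their "effective number of samples". A minimal sketch of that formula, assuming per-class positive counts have already been extracted from the raw files; the exact normalisation inside set_pos_weight is not visible in this hunk:

import torch

def class_balanced_weights(positive_counts: torch.Tensor, beta: float = 0.99) -> torch.Tensor:
    # effective number of samples per class: E_c = (1 - beta^n_c) / (1 - beta)
    effective_num = 1.0 - torch.pow(beta, positive_counts.float())
    # weights are proportional to 1 / E_c
    weights = (1.0 - beta) / effective_num
    return weights / weights.mean()  # normalise so the average weight is 1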

chebai/loss/semantic.py

Lines changed: 20 additions & 6 deletions
@@ -7,20 +7,29 @@
 from typing import Literal
 
 from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
+from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
+from chebai.loss.bce_weighted import BCEWeighted
 
 
 class ImplicationLoss(torch.nn.Module):
     def __init__(
         self,
-        data_extractor: _ChEBIDataExtractor,
+        data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
         base_loss: torch.nn.Module = None,
         tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
         impl_loss_weight=0.1,  # weight of implication loss in relation to base_loss
         pos_scalar=1,
         pos_epsilon=0.01,
+        multiply_by_softmax=False,
     ):
         super().__init__()
+        # automatically choose labeled subset for implication filter in case of mixed dataset
+        if isinstance(data_extractor, LabeledUnlabeledMixed):
+            data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
+        # propagate data_extractor to base loss
+        if isinstance(base_loss, BCEWeighted):
+            base_loss.data_extractor = self.data_extractor
         self.base_loss = base_loss
         self.implication_cache_file = f"implications_{self.data_extractor.name}.cache"
         self.label_names = _load_label_names(
@@ -36,6 +45,7 @@ def __init__(
         self.impl_weight = impl_loss_weight
         self.pos_scalar = pos_scalar
         self.eps = pos_epsilon
+        self.multiply_by_softmax = multiply_by_softmax
 
     def forward(self, input, target, **kwargs):
         nnl = kwargs.pop("non_null_labels", None)
@@ -70,16 +80,20 @@ def _calculate_implication_loss(self, l, r):
                 math.pow(1 + self.eps, 1 / self.pos_scalar)
                 - math.pow(self.eps, 1 / self.pos_scalar)
             )
-            r = torch.pow(r, self.pos_scalar)
+            one_min_r = torch.pow(1 - r, self.pos_scalar)
+        else:
+            one_min_r = 1 - r
         if self.tnorm == "product":
-            individual_loss = l * (1 - r)
+            individual_loss = l * one_min_r
         elif self.tnorm == "xu19":
-            individual_loss = -torch.log(1 - l * (1 - r))
+            individual_loss = -torch.log(1 - l * one_min_r)
         elif self.tnorm == "lukasiewicz":
-            individual_loss = torch.relu(l - r)
+            individual_loss = torch.relu(l + one_min_r - 1)
         else:
             raise NotImplementedError(f"Unknown tnorm {self.tnorm}")
 
+        if self.multiply_by_softmax:
+            individual_loss = individual_loss * individual_loss.softmax(dim=-1)
         return torch.mean(
             torch.sum(individual_loss, dim=-1),
             dim=0,
@@ -100,7 +114,7 @@ class DisjointLoss(ImplicationLoss):
     def __init__(
         self,
         path_to_disjointness,
-        data_extractor: _ChEBIDataExtractor,
+        data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
         base_loss: torch.nn.Module = None,
         disjoint_loss_weight=100,
         **kwargs,
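A condensed view of the three t-norm variants after this change (a sketch for the defaults pos_scalar=1 and multiply_by_softmax=False, where l and r are the predicted probabilities of the antecedent and consequent of an implication l → r):

import torch

def implication_violation(l: torch.Tensor, r: torch.Tensor, tnorm: str = "product") -> torch.Tensor:
    one_min_r = 1 - r
    if tnorm == "product":
        return l * one_min_r                  # product t-norm
    if tnorm == "xu19":
        return -torch.log(1 - l * one_min_r)  # log-transformed product (Xu et al., 2019)
    if tnorm == "lukasiewicz":
        return torch.relu(l + one_min_r - 1)  # Łukasiewicz t-norm applied to (l, 1 - r)
    raise NotImplementedError(f"Unknown tnorm {tnorm}")

Note that for pos_scalar=1 the Łukasiewicz branch reduces to the previous torch.relu(l - r); the rewrite changes behaviour only when pos_scalar != 1, where one_min_r becomes (1 - r) ** pos_scalar.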

chebai/models/base.py

Lines changed: 0 additions & 2 deletions
@@ -28,11 +28,9 @@ class ChebaiBaseNet(LightningModule):
 
     Attributes:
         NAME (str): The name of the model.
-        LOSS (torch.nn.Module): The loss function used by the model.
     """
 
     NAME = None
-    LOSS = torch.nn.BCEWithLogitsLoss
 
     def __init__(
         self,

chebai/preprocessing/datasets/base.py

Lines changed: 4 additions & 0 deletions
@@ -313,6 +313,10 @@ def setup_processed(self):
     def processed_file_names(self):
         raise NotImplementedError
 
+    @property
+    def raw_file_names(self):
+        raise NotImplementedError
+
     @property
     def processed_file_names_dict(self) -> dict:
         raise NotImplementedError

chebai/preprocessing/datasets/chebi.py

Lines changed: 4 additions & 4 deletions
@@ -192,7 +192,7 @@ def graph_to_raw_dataset(self, g, split_name=None):
         return data
 
     def save_raw(self, data: pd.DataFrame, filename: str):
-        pickle.dump(data, open(os.path.join(self.raw_dir, filename), "wb"))
+        pd.to_pickle(data, open(os.path.join(self.raw_dir, filename), "wb"))
 
     def _load_dict(self, input_file_path):
         """
@@ -205,7 +205,7 @@ def _load_dict(self, input_file_path):
             dict: The dictionary, keys are `features`, `labels` and `ident`.
         """
         with open(input_file_path, "rb") as input_file:
-            df = pickle.load(input_file)
+            df = pd.read_pickle(input_file)
         if self.single_class is not None:
             single_cls_index = list(df.columns).index(int(self.single_class))
             for row in df.values:
@@ -218,7 +218,7 @@ def _load_dict(self, input_file_path):
     @staticmethod
     def _get_data_size(input_file_path):
         with open(input_file_path, "rb") as f:
-            return len(pickle.load(f))
+            return len(pd.read_pickle(f))
 
     def _setup_pruned_test_set(self):
         """Create test set with same leaf nodes, but use classes that appear in train set"""
@@ -468,7 +468,7 @@ def prepare_data(self, *args, **kwargs):
         with open(
             os.path.join(self.raw_dir, self.raw_file_names_dict["test"]), "rb"
         ) as input_file:
-            test_df = pickle.load(input_file)
+            test_df = pd.read_pickle(input_file)
         # create train/val split based on test set
         chebi_path = self._load_chebi(
             self.chebi_version_train
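As an aside, pd.read_pickle and pd.to_pickle also accept a path directly, so the explicit open(...) calls above could be dropped; a small sketch with a hypothetical path:

import pandas as pd

df = pd.read_pickle("data/raw/test.pkl")   # hypothetical path
pd.to_pickle(df, "data/raw/test_copy.pkl")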
