Skip to content

Commit dd45138

Browse files
authored
Merge pull request #25 from ChEB-AI/feature/semantic-losss
Feature/semantic losss
2 parents 16fe335 + 1523f4a commit dd45138

File tree

10 files changed

+655
-60
lines changed

10 files changed

+655
-60
lines changed

chebai/cli.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,9 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
1919
parser.link_arguments(
2020
"model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
2121
)
22+
parser.link_arguments(
23+
"data", "model.init_args.criterion.init_args.data_extractor"
24+
)
2225

2326
@staticmethod
2427
def subcommands() -> Dict[str, Set[str]]:

chebai/loss/bce_weighted.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import torch
2-
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
2+
from chebai.preprocessing.datasets.base import XYBaseDataModule
3+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
34
import pandas as pd
45
import os
56
import pickle
@@ -10,9 +11,16 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
1011
https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf)
1112
"""
1213

13-
def __init__(self, beta: float = None, data_extractor: _ChEBIDataExtractor = None):
14+
def __init__(
15+
self,
16+
beta: float = None,
17+
data_extractor: XYBaseDataModule = None,
18+
):
1419
self.beta = beta
20+
if isinstance(data_extractor, LabeledUnlabeledMixed):
21+
data_extractor = data_extractor.labeled
1522
self.data_extractor = data_extractor
23+
1624
super().__init__()
1725

1826
def set_pos_weight(self, input):
@@ -31,12 +39,12 @@ def set_pos_weight(self, input):
3139
open(
3240
os.path.join(
3341
self.data_extractor.raw_dir,
34-
self.data_extractor.raw_file_names_dict[set],
42+
raw_file_name,
3543
),
3644
"rb",
3745
)
3846
)
39-
for set in ["train", "validation", "test"]
47+
for raw_file_name in self.data_extractor.raw_file_names
4048
]
4149
)
4250
value_counts = []

chebai/loss/semantic.py

Lines changed: 20 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,20 +7,29 @@
77
from typing import Literal
88

99
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
10+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
11+
from chebai.loss.bce_weighted import BCEWeighted
1012

1113

1214
class ImplicationLoss(torch.nn.Module):
1315
def __init__(
1416
self,
15-
data_extractor: _ChEBIDataExtractor,
17+
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
1618
base_loss: torch.nn.Module = None,
1719
tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
1820
impl_loss_weight=0.1, # weight of implication loss in relation to base_loss
1921
pos_scalar=1,
2022
pos_epsilon=0.01,
23+
multiply_by_softmax=False,
2124
):
2225
super().__init__()
26+
# automatically choose labeled subset for implication filter in case of mixed dataset
27+
if isinstance(data_extractor, LabeledUnlabeledMixed):
28+
data_extractor = data_extractor.labeled
2329
self.data_extractor = data_extractor
30+
# propagate data_extractor to base loss
31+
if isinstance(base_loss, BCEWeighted):
32+
base_loss.data_extractor = self.data_extractor
2433
self.base_loss = base_loss
2534
self.implication_cache_file = f"implications_{self.data_extractor.name}.cache"
2635
self.label_names = _load_label_names(
@@ -36,6 +45,7 @@ def __init__(
3645
self.impl_weight = impl_loss_weight
3746
self.pos_scalar = pos_scalar
3847
self.eps = pos_epsilon
48+
self.multiply_by_softmax = multiply_by_softmax
3949

4050
def forward(self, input, target, **kwargs):
4151
nnl = kwargs.pop("non_null_labels", None)
@@ -70,16 +80,20 @@ def _calculate_implication_loss(self, l, r):
7080
math.pow(1 + self.eps, 1 / self.pos_scalar)
7181
- math.pow(self.eps, 1 / self.pos_scalar)
7282
)
73-
r = torch.pow(r, self.pos_scalar)
83+
one_min_r = torch.pow(1 - r, self.pos_scalar)
84+
else:
85+
one_min_r = 1 - r
7486
if self.tnorm == "product":
75-
individual_loss = l * (1 - r)
87+
individual_loss = l * one_min_r
7688
elif self.tnorm == "xu19":
77-
individual_loss = -torch.log(1 - l * (1 - r))
89+
individual_loss = -torch.log(1 - l * one_min_r)
7890
elif self.tnorm == "lukasiewicz":
79-
individual_loss = torch.relu(l - r)
91+
individual_loss = torch.relu(l + one_min_r - 1)
8092
else:
8193
raise NotImplementedError(f"Unknown tnorm {self.tnorm}")
8294

95+
if self.multiply_by_softmax:
96+
individual_loss = individual_loss * individual_loss.softmax(dim=-1)
8397
return torch.mean(
8498
torch.sum(individual_loss, dim=-1),
8599
dim=0,
@@ -100,7 +114,7 @@ class DisjointLoss(ImplicationLoss):
100114
def __init__(
101115
self,
102116
path_to_disjointness,
103-
data_extractor: _ChEBIDataExtractor,
117+
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
104118
base_loss: torch.nn.Module = None,
105119
disjoint_loss_weight=100,
106120
**kwargs,

chebai/models/base.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,11 +28,9 @@ class ChebaiBaseNet(LightningModule):
2828
2929
Attributes:
3030
NAME (str): The name of the model.
31-
LOSS (torch.nn.Module): The loss function used by the model.
3231
"""
3332

3433
NAME = None
35-
LOSS = torch.nn.BCEWithLogitsLoss
3634

3735
def __init__(
3836
self,

chebai/preprocessing/datasets/base.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -312,6 +312,10 @@ def setup_processed(self):
312312
def processed_file_names(self):
313313
raise NotImplementedError
314314

315+
@property
316+
def raw_file_names(self):
317+
raise NotImplementedError
318+
315319
@property
316320
def processed_file_names_dict(self) -> dict:
317321
raise NotImplementedError

0 commit comments

Comments
 (0)