Commit 9601c64

Merge pull request #69 from ChEB-AI/feature-fuzzy-loss
Fuzzy loss (10/24)
2 parents 74f3ab9 + e7bd80a commit 9601c64

File tree

11 files changed: +964 -459 lines changed

chebai/cli.py

Lines changed: 4 additions & 0 deletions
@@ -50,6 +50,10 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
         parser.link_arguments(
             "data", "model.init_args.criterion.init_args.data_extractor"
         )
+        parser.link_arguments(
+            "data.init_args.chebi_version",
+            "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
+        )
 
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
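The added link forwards the ChEBI version configured for the data module to the criterion's data extractor, so the version only has to be given once in the config. Below is a minimal sketch of the underlying jsonargparse mechanism that LightningArgumentParser builds on; the argument names are illustrative, not taken from the repository.

# Hypothetical keys, for illustration only.
from jsonargparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--data.chebi_version", type=int, default=200)
parser.add_argument("--criterion.chebi_version", type=int, default=None)

# After linking, the criterion's version is filled from the data module's
# value at parse time and no longer has to be specified separately.
parser.link_arguments("data.chebi_version", "criterion.chebi_version")

cfg = parser.parse_args(["--data.chebi_version=231"])
print(cfg.criterion.chebi_version)  # expected: 231, propagated by the link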

chebai/loss/bce_weighted.py

Lines changed: 15 additions & 4 deletions
@@ -5,14 +5,16 @@
 import torch
 
 from chebai.preprocessing.datasets.base import XYBaseDataModule
+from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
 from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
 
 
 class BCEWeighted(torch.nn.BCEWithLogitsLoss):
     """
     BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
+    If beta is None or data_extractor is None, the loss is unweighted.
 
-    This class computes weights based on the formula from the paper:
+    This class computes weights based on the formula from the paper by Cui et al. (2019):
     https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
 
     Args:
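For context, the cited class-balanced loss reweights each class by the inverse of its "effective number of samples". A minimal sketch of that formula follows; it is not the repository's set_pos_weight implementation, and the names are illustrative.

import torch

def class_balanced_weights(class_counts: torch.Tensor, beta: float) -> torch.Tensor:
    """Per-class weights w_c = (1 - beta) / (1 - beta**n_c), normalized to mean 1.

    class_counts holds the number of positive samples per class; a beta close
    to 1 (e.g. 0.99 or 0.999) downweights frequent classes more strongly.
    """
    effective_num = 1.0 - torch.pow(beta, class_counts)
    weights = (1.0 - beta) / effective_num
    return weights / weights.mean()

# e.g. three classes with 10, 100 and 1000 positive examples
print(class_balanced_weights(torch.tensor([10.0, 100.0, 1000.0]), beta=0.99))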
@@ -24,13 +26,17 @@ def __init__(
         self,
         beta: Optional[float] = None,
         data_extractor: Optional[XYBaseDataModule] = None,
+        **kwargs,
     ):
         self.beta = beta
         if isinstance(data_extractor, LabeledUnlabeledMixed):
             data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
-
-        super().__init__()
+        assert (
+            isinstance(self.data_extractor, _ChEBIDataExtractor)
+            or self.data_extractor is None
+        )
+        super().__init__(**kwargs)
 
     def set_pos_weight(self, input: torch.Tensor) -> None:
         """
@@ -50,6 +56,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             )
             and self.pos_weight is None
         ):
+            print(
+                f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
+            )
             complete_data = pd.concat(
                 [
                     pd.read_pickle(
@@ -75,7 +84,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             [w / mean for w in weights], device=input.device
         )
 
-    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, input: torch.Tensor, target: torch.Tensor, **kwargs
+    ) -> torch.Tensor:
         """
         Forward pass for the loss calculation.
 
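The net effect is that BCEWeighted behaves like a plain BCEWithLogitsLoss whose pos_weight has been filled with the class-balanced weights before the first forward pass. A minimal sketch of that mechanism in plain PyTorch (illustrative shapes and values, not the repository's data):

import torch

logits = torch.randn(4, 3)                       # batch of 4 samples, 3 classes
targets = torch.randint(0, 2, (4, 3)).float()    # multi-label targets in {0, 1}

# Hypothetical per-class weights, e.g. as produced by class_balanced_weights
# above; rare classes get weights > 1, frequent ones < 1.
pos_weight = torch.tensor([2.5, 0.8, 0.3])

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, targets)
print(loss)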
