55import torch
66
77from chebai .preprocessing .datasets .base import XYBaseDataModule
8+ from chebai .preprocessing .datasets .chebi import _ChEBIDataExtractor
89from chebai .preprocessing .datasets .pubchem import LabeledUnlabeledMixed
910
1011
1112class BCEWeighted (torch .nn .BCEWithLogitsLoss ):
1213 """
1314 BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
15+ If beta is None or data_extractor is None, the loss is unweighted.
1416
15- This class computes weights based on the formula from the paper:
17+ This class computes weights based on the formula from the paper by Cui et al. (2019) :
1618 https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
1719
1820 Args:
@@ -24,13 +26,17 @@ def __init__(
2426 self ,
2527 beta : Optional [float ] = None ,
2628 data_extractor : Optional [XYBaseDataModule ] = None ,
29+ ** kwargs ,
2730 ):
2831 self .beta = beta
2932 if isinstance (data_extractor , LabeledUnlabeledMixed ):
3033 data_extractor = data_extractor .labeled
3134 self .data_extractor = data_extractor
32-
33- super ().__init__ ()
35+ assert (
36+ isinstance (self .data_extractor , _ChEBIDataExtractor )
37+ or self .data_extractor is None
38+ )
39+ super ().__init__ (** kwargs )
3440
3541 def set_pos_weight (self , input : torch .Tensor ) -> None :
3642 """
@@ -50,6 +56,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
5056 )
5157 and self .pos_weight is None
5258 ):
59+ print (
60+ f"Computing loss-weights based on v{ self .data_extractor .chebi_version } dataset (beta={ self .beta } )"
61+ )
5362 complete_data = pd .concat (
5463 [
5564 pd .read_pickle (
@@ -75,7 +84,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
7584 [w / mean for w in weights ], device = input .device
7685 )
7786
78- def forward (self , input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
87+ def forward (
88+ self , input : torch .Tensor , target : torch .Tensor , ** kwargs
89+ ) -> torch .Tensor :
7990 """
8091 Forward pass for the loss calculation.
8192
0 commit comments