1010class BCEWeighted (torch .nn .BCEWithLogitsLoss ):
1111 """
1212 BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
13- If beta is None or data_extractor is None, the loss is unweighted.
1413
1514 This class computes weights based on the formula from the paper by Cui et al. (2019):
1615 https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
@@ -33,19 +32,22 @@ def __init__(
3332 data_extractor = data_extractor .labeled
3433 self .data_extractor = data_extractor
3534
36- assert (
37- beta is not None
38- ), f"Beta parameter must be provided if this loss ({ self .__class__ .__name__ } ) is used."
35+ assert self .beta is not None and self .data_extractor is not None , (
36+ f"Beta parameter must be provided along with data_extractor, "
37+ f"if this loss class ({ self .__class__ .__name__ } ) is used."
38+ )
3939
40- # If beta is provided, require a data_extractor.
41- if self .beta is not None and self .data_extractor is None :
42- raise ValueError ("When 'beta' is set, 'data_extractor' must also be set." )
40+ assert all (
41+ os .path .exists (os .path .join (self .data_extractor .processed_dir , file_name ))
42+ for file_name in self .data_extractor .processed_file_names
43+ ), "Dataset files not found. Make sure the dataset is processed before using this loss."
4344
4445 assert (
4546 isinstance (self .data_extractor , _ChEBIDataExtractor )
4647 or self .data_extractor is None
4748 )
4849 super ().__init__ (** kwargs )
50+ self .pos_weight : Optional [torch .Tensor ] = None
4951
5052 def set_pos_weight (self , input : torch .Tensor ) -> None :
5153 """
@@ -54,17 +56,7 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
5456 Args:
5557 input (torch.Tensor): The input tensor for which to set the positive weights.
5658 """
57- if (
58- self .beta is not None
59- and self .data_extractor is not None
60- and all (
61- os .path .exists (
62- os .path .join (self .data_extractor .processed_dir , file_name )
63- )
64- for file_name in self .data_extractor .processed_file_names
65- )
66- and self .pos_weight is None
67- ):
59+ if self .pos_weight is None :
6860 print (
6961 f"Computing loss-weights based on v{ self .data_extractor .chebi_version } dataset (beta={ self .beta } )"
7062 )
@@ -105,3 +97,9 @@ def forward(
10597 """
10698 self .set_pos_weight (input )
10799 return super ().forward (input , target )
100+
101+
102+ class UnWeightedBCEWithLogitsLoss (torch .nn .BCEWithLogitsLoss ):
103+ def forward (self , input , target , ** kwargs ):
104+ # As the custom passed kwargs are not used in BCEWithLogitsLoss, we can ignore them
105+ return super ().forward (input , target )
0 commit comments