
Commit 65c448b

BCEWeighted class should only be used if weighting is the intention; otherwise, another class is provided
1 parent 6c99441 commit 65c448b

File tree

1 file changed: +16 −18


chebai/loss/bce_weighted.py

Lines changed: 16 additions & 18 deletions
@@ -10,7 +10,6 @@
 class BCEWeighted(torch.nn.BCEWithLogitsLoss):
     """
     BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
-    If beta is None or data_extractor is None, the loss is unweighted.
 
     This class computes weights based on the formula from the paper by Cui et al. (2019):
     https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
@@ -33,19 +32,22 @@ def __init__(
             data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
 
-        assert (
-            beta is not None
-        ), f"Beta parameter must be provided if this loss ({self.__class__.__name__}) is used."
+        assert self.beta is not None and self.data_extractor is not None, (
+            f"Beta parameter must be provided along with data_extractor, "
+            f"if this loss class ({self.__class__.__name__}) is used."
+        )
 
-        # If beta is provided, require a data_extractor.
-        if self.beta is not None and self.data_extractor is None:
-            raise ValueError("When 'beta' is set, 'data_extractor' must also be set.")
+        assert all(
+            os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name))
+            for file_name in self.data_extractor.processed_file_names
+        ), "Dataset files not found. Make sure the dataset is processed before using this loss."
 
         assert (
             isinstance(self.data_extractor, _ChEBIDataExtractor)
             or self.data_extractor is None
         )
         super().__init__(**kwargs)
+        self.pos_weight: Optional[torch.Tensor] = None
 
     def set_pos_weight(self, input: torch.Tensor) -> None:
         """
@@ -54,17 +56,7 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
         Args:
             input (torch.Tensor): The input tensor for which to set the positive weights.
         """
-        if (
-            self.beta is not None
-            and self.data_extractor is not None
-            and all(
-                os.path.exists(
-                    os.path.join(self.data_extractor.processed_dir, file_name)
-                )
-                for file_name in self.data_extractor.processed_file_names
-            )
-            and self.pos_weight is None
-        ):
+        if self.pos_weight is None:
             print(
                 f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
             )
@@ -105,3 +97,9 @@ def forward(
         """
         self.set_pos_weight(input)
         return super().forward(input, target)
+
+
+class UnWeightedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
+    def forward(self, input, target, **kwargs):
+        # As the custom passed kwargs are not used in BCEWithLogitsLoss, we can ignore them
+        return super().forward(input, target)
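For context on the weighting the docstring references: Cui et al. (2019) define the effective number of samples for a class with n positive examples as E_n = (1 − beta^n) / (1 − beta) and weight each class proportionally to 1 / E_n. A minimal sketch of that formula, assuming per-class positive counts are already available (class_balanced_pos_weights is a hypothetical helper, not the computation inside BCEWeighted):

import torch

def class_balanced_pos_weights(pos_counts: torch.Tensor, beta: float) -> torch.Tensor:
    # Hypothetical helper illustrating Cui et al. (2019); not chebai's implementation.
    # Effective number of samples per class: E_n = (1 - beta**n) / (1 - beta)
    effective_num = (1.0 - beta**pos_counts) / (1.0 - beta)
    # Weights are proportional to 1 / E_n, normalized to sum to the class count.
    weights = 1.0 / effective_num
    return weights * pos_counts.numel() / weights.sum()

# A rare class (10 positives) gets a far larger weight than a common one (10,000):
print(class_balanced_pos_weights(torch.tensor([10.0, 10_000.0]), beta=0.99))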

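The commit message's intent, in code form: pick BCEWeighted only when weighting is wanted, and the new UnWeightedBCEWithLogitsLoss otherwise. A hedged usage sketch, assuming the constructor arguments shown in the diff and an import path derived from the file header (extractor is a placeholder for a processed _ChEBIDataExtractor; beta=0.99 is just an example value):

from chebai.loss.bce_weighted import BCEWeighted, UnWeightedBCEWithLogitsLoss

# Weighted loss: beta and a processed data_extractor are now both required.
criterion = BCEWeighted(beta=0.99, data_extractor=extractor)  # extractor is a placeholder

# Unweighted loss: use the dedicated class; its forward() accepts and ignores
# any extra kwargs, so callers that pass custom kwargs still work.
criterion = UnWeightedBCEWithLogitsLoss()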