@@ -34,6 +34,7 @@ class AsymmetricFocalTverskyLoss(_Loss):
3434 def __init__ (
3535 self ,
3636 to_onehot_y : bool = False ,
37+ num_classes : int = 2 ,
3738 delta : float = 0.7 ,
3839 gamma : float = 0.75 ,
3940 epsilon : float = 1e-7 ,
@@ -43,6 +44,7 @@ def __init__(
4344 """
4445 Args:
4546 to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
47+ num_classes: number of classes. Defaults to 2.
4648 delta: weight of the background class (used in the Tversky index denominator). Defaults to 0.7.
4749 gamma: focal exponent value to down-weight easy foreground examples. Defaults to 0.75.
4850 epsilon: a small value to prevent division by zero. Defaults to 1e-7.
@@ -54,6 +56,7 @@ def __init__(
5456 """
5557 super ().__init__ (reduction = LossReduction (reduction ).value )
5658 self .to_onehot_y = to_onehot_y
59+ self .num_classes = num_classes
5760 self .delta = delta
5861 self .gamma = gamma
5962 self .epsilon = epsilon
@@ -72,7 +75,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7275 is_already_prob = True
7376
7477 if y_true .shape [1 ] == 1 :
75- y_true = one_hot (y_true , num_classes = 2 )
78+ y_true = one_hot (y_true , num_classes = self . num_classes )
7679 else :
7780 is_already_prob = False
7881
@@ -120,13 +123,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
120123
121124 # Apply reduction
122125 if self .reduction == LossReduction .MEAN .value :
123- return torch .mean (all_losses )
126+ return torch .mean (torch . sum ( all_losses , dim = 1 ) )
124127 elif self .reduction == LossReduction .SUM .value :
125128 return torch .sum (all_losses )
126129 elif self .reduction == LossReduction .NONE .value :
127130 return all_losses
128131 else :
129- return torch .mean (all_losses )
132+ return torch .mean (torch . sum ( all_losses , dim = 1 ) )
130133
131134
132135class AsymmetricFocalLoss (_Loss ):
@@ -143,6 +146,7 @@ class AsymmetricFocalLoss(_Loss):
143146 def __init__ (
144147 self ,
145148 to_onehot_y : bool = False ,
149+ num_classes : int = 2 ,
146150 delta : float = 0.7 ,
147151 gamma : float = 2.0 ,
148152 epsilon : float = 1e-7 ,
@@ -152,6 +156,7 @@ def __init__(
152156 """
153157 Args:
154158 to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
159+ num_classes: number of classes. Defaults to 2.
155160 delta: weight for the foreground classes. Defaults to 0.7.
156161 gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0.
157162 epsilon: a small value to prevent calculation errors. Defaults to 1e-7.
@@ -160,6 +165,7 @@ def __init__(
160165 """
161166 super ().__init__ (reduction = LossReduction (reduction ).value )
162167 self .to_onehot_y = to_onehot_y
168+ self .num_classes = num_classes
163169 self .delta = delta
164170 self .gamma = gamma
165171 self .epsilon = epsilon
@@ -172,12 +178,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
172178 y_true: ground truth labels.
173179 """
174180
175- if y_pred .shape [1 ] == 1 and not self . use_softmax :
181+ if y_pred .shape [1 ] == 1 :
176182 y_pred = torch .sigmoid (y_pred )
177183 y_pred = torch .cat ([1 - y_pred , y_pred ], dim = 1 )
178184 is_already_prob = True
179185 if y_true .shape [1 ] == 1 :
180- y_true = one_hot (y_true , num_classes = 2 )
186+ y_true = one_hot (y_true , num_classes = self . num_classes )
181187 else :
182188 is_already_prob = False
183189
@@ -270,13 +276,15 @@ def __init__(
270276 self .use_softmax = use_softmax
271277
272278 self .asy_focal_loss = AsymmetricFocalLoss (
279+ num_classes = self .num_classes ,
273280 gamma = self .gamma ,
274281 delta = self .delta ,
275282 use_softmax = self .use_softmax ,
276283 to_onehot_y = to_onehot_y ,
277284 reduction = LossReduction .NONE ,
278285 )
279286 self .asy_focal_tversky_loss = AsymmetricFocalTverskyLoss (
287+ num_classes = self .num_classes ,
280288 gamma = self .gamma ,
281289 delta = self .delta ,
282290 use_softmax = self .use_softmax ,
0 commit comments