Skip to content

Commit cf53eae

Browse files
committed
added num_classes parameter
Signed-off-by: ytl0623 <[email protected]>
1 parent 8368ef2 commit cf53eae

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ class AsymmetricFocalTverskyLoss(_Loss):
3434
def __init__(
3535
self,
3636
to_onehot_y: bool = False,
37+
num_classes: int = 2,
3738
delta: float = 0.7,
3839
gamma: float = 0.75,
3940
epsilon: float = 1e-7,
@@ -43,6 +44,7 @@ def __init__(
4344
"""
4445
Args:
4546
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
47+
num_classes: number of classes. Defaults to 2.
4648
delta: weight of the background class (used in the Tversky index denominator). Defaults to 0.7.
4749
gamma: focal exponent value to down-weight easy foreground examples. Defaults to 0.75.
4850
epsilon: a small value to prevent division by zero. Defaults to 1e-7.
@@ -54,6 +56,7 @@ def __init__(
5456
"""
5557
super().__init__(reduction=LossReduction(reduction).value)
5658
self.to_onehot_y = to_onehot_y
59+
self.num_classes = num_classes
5760
self.delta = delta
5861
self.gamma = gamma
5962
self.epsilon = epsilon
@@ -72,7 +75,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7275
is_already_prob = True
7376

7477
if y_true.shape[1] == 1:
75-
y_true = one_hot(y_true, num_classes=2)
78+
y_true = one_hot(y_true, num_classes=self.num_classes)
7679
else:
7780
is_already_prob = False
7881

@@ -120,13 +123,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
120123

121124
# Apply reduction
122125
if self.reduction == LossReduction.MEAN.value:
123-
return torch.mean(all_losses)
126+
return torch.mean(torch.sum(all_losses, dim=1))
124127
elif self.reduction == LossReduction.SUM.value:
125128
return torch.sum(all_losses)
126129
elif self.reduction == LossReduction.NONE.value:
127130
return all_losses
128131
else:
129-
return torch.mean(all_losses)
132+
return torch.mean(torch.sum(all_losses, dim=1))
130133

131134

132135
class AsymmetricFocalLoss(_Loss):
@@ -143,6 +146,7 @@ class AsymmetricFocalLoss(_Loss):
143146
def __init__(
144147
self,
145148
to_onehot_y: bool = False,
149+
num_classes: int = 2,
146150
delta: float = 0.7,
147151
gamma: float = 2.0,
148152
epsilon: float = 1e-7,
@@ -152,6 +156,7 @@ def __init__(
152156
"""
153157
Args:
154158
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
159+
num_classes: number of classes. Defaults to 2.
155160
delta: weight for the foreground classes. Defaults to 0.7.
156161
gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0.
157162
epsilon: a small value to prevent calculation errors. Defaults to 1e-7.
@@ -160,6 +165,7 @@ def __init__(
160165
"""
161166
super().__init__(reduction=LossReduction(reduction).value)
162167
self.to_onehot_y = to_onehot_y
168+
self.num_classes = num_classes
163169
self.delta = delta
164170
self.gamma = gamma
165171
self.epsilon = epsilon
@@ -172,12 +178,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
172178
y_true: ground truth labels.
173179
"""
174180

175-
if y_pred.shape[1] == 1 and not self.use_softmax:
181+
if y_pred.shape[1] == 1:
176182
y_pred = torch.sigmoid(y_pred)
177183
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
178184
is_already_prob = True
179185
if y_true.shape[1] == 1:
180-
y_true = one_hot(y_true, num_classes=2)
186+
y_true = one_hot(y_true, num_classes=self.num_classes)
181187
else:
182188
is_already_prob = False
183189

@@ -270,13 +276,15 @@ def __init__(
270276
self.use_softmax = use_softmax
271277

272278
self.asy_focal_loss = AsymmetricFocalLoss(
279+
num_classes=self.num_classes,
273280
gamma=self.gamma,
274281
delta=self.delta,
275282
use_softmax=self.use_softmax,
276283
to_onehot_y=to_onehot_y,
277284
reduction=LossReduction.NONE,
278285
)
279286
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
287+
num_classes=self.num_classes,
280288
gamma=self.gamma,
281289
delta=self.delta,
282290
use_softmax=self.use_softmax,

0 commit comments

Comments (0)