Skip to content

Commit 945abfe

Browse files
committed
fix: When alpha is a sequence, each alpha[c] should be interpreted as the weight for positive samples of class c. Negative samples should have a default weight of 1.0
Signed-off-by: ytl0623 <[email protected]>
1 parent dc19ec0 commit 945abfe

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

monai/losses/focal_loss.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
8383
alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
8484
The value should be in [0, 1].
85-
If a sequence is provided, it must match the number of classes (after excluding background if set).
85+
If a sequence is provided, its length must match the number of classes (excluding the background class if `include_background=False`).
8686
Defaults to None.
8787
weight: weights to apply to the voxels of each class. If None no weights are applied.
8888
The input can be a single value (same weight for all classes), a sequence of values (the length
@@ -289,8 +289,10 @@ def sigmoid_focal_loss(
289289
# Reshape alpha for broadcasting: (1, C, 1, 1...)
290290
broadcast_dims = [-1] + [1] * len(target.shape[2:])
291291
alpha_t = alpha_t.view(broadcast_dims)
292-
# Apply alpha_c if t==1, (1-alpha_c) if t==0 for channel c
293-
alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
292+
# Apply per-class weight only to positive samples
293+
# For positive samples (target==1): multiply by alpha[c]
294+
# For negative samples (target==0): keep weight as 1.0
295+
alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))
294296

295297
loss = alpha_factor * loss
296298

tests/losses/test_focal_loss.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from tests.test_utils import TEST_DEVICES, test_script_save
2525

2626
TEST_CASES = []
27-
for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
27+
for case in TEST_DEVICES:
28+
device = case[0]
2829
input_data = {
2930
"input": torch.tensor(
3031
[[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]], device=device
@@ -79,10 +80,10 @@
7980

8081
TEST_ALPHA_BROADCASTING = []
8182
for case in TEST_DEVICES:
82-
dev = case[0]
83+
device = case[0]
8384
for include_background in [True, False]:
8485
for use_softmax in [True, False]:
85-
TEST_ALPHA_BROADCASTING.append([dev, include_background, use_softmax])
86+
TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])
8687

8788

8889
class TestFocalLoss(unittest.TestCase):

0 commit comments

Comments (0)