1919
2020from monai .losses import AsymmetricUnifiedFocalLoss
2121
# Logits far from zero so that sigmoid/softmax probabilities saturate at ~1.0
# and ~0.0; a perfect prediction then yields a loss of (numerically) exactly 0.
logit_pos = 10.0
logit_neg = -10.0

# Each case: [constructor kwargs, forward kwargs (y_pred logits, y_true one-hot),
#             expected loss value].
TEST_CASES = [
    [  # Case 0: Binary segmentation
        # shape: (2, 1, 2, 2), (2, 1, 2, 2)
        {"use_softmax": False, "include_background": True},
        {
            "y_pred": torch.tensor(
                [[[[logit_pos, logit_neg], [logit_neg, logit_pos]]], [[[logit_pos, logit_neg], [logit_neg, logit_pos]]]]
            ),
            "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]], [[[1.0, 0.0], [0.0, 1.0]]]]),
        },
        0.0,
    ],
    [  # Case 1: Multi-class segmentation with softmax
        # shape: (1, 3, 2, 2), (1, 3, 2, 2)
        {"use_softmax": True, "include_background": True},
        {
            "y_pred": torch.tensor(
                [
                    [
                        [[logit_pos, logit_neg], [logit_neg, logit_neg]],  # Class 0 (background)
                        [[logit_neg, logit_pos], [logit_neg, logit_neg]],  # Class 1
                        [[logit_neg, logit_neg], [logit_pos, logit_pos]],  # Class 2
                    ]
                ]
            ),
            "y_true": torch.tensor(
                [
                    [
                        [[1.0, 0.0], [0.0, 0.0]],  # Class 0 (background)
                        [[0.0, 1.0], [0.0, 0.0]],  # Class 1
                        [[0.0, 0.0], [1.0, 1.0]],  # Class 2
                    ]
                ]
            ),
        },
        0.0,
    ],
    [  # Case 2: Multi-class with background excluded
        # shape: (1, 3, 2, 2), (1, 3, 2, 2)
        {"use_softmax": True, "include_background": False},
        {
            "y_pred": torch.tensor(
                [
                    [
                        [[logit_pos, logit_neg], [logit_neg, logit_neg]],  # Class 0 (background)
                        [[logit_neg, logit_pos], [logit_pos, logit_neg]],  # Class 1 (foreground)
                        [[logit_neg, logit_neg], [logit_neg, logit_pos]],  # Class 2 (foreground)
                    ]
                ]
            ),
            "y_true": torch.tensor(
                [
                    [
                        [[1.0, 0.0], [0.0, 0.0]],  # Class 0 (background)
                        [[0.0, 1.0], [1.0, 0.0]],  # Class 1 (foreground)
                        [[0.0, 0.0], [0.0, 1.0]],  # Class 2 (foreground)
                    ]
                ]
            ),
        },
        0.0,
    ],
]
4096class TestAsymmetricUnifiedFocalLoss (unittest .TestCase ):
4197
4298 @parameterized .expand (TEST_CASES )
43- def test_result (self , input_data , expected_val ):
44- loss = AsymmetricUnifiedFocalLoss ()
99+ def test_result (self , input_param , input_data , expected_val ):
100+ """
101+ Test AsymmetricUnifiedFocalLoss with various configurations.
102+
103+ Args:
104+ input_param: Dict of loss constructor parameters (use_softmax, include_background, etc.).
105+ input_data: Dict containing y_pred (logits) and y_true (ground truth) tensors.
106+ expected_val: Expected loss value.
107+ """
108+ loss = AsymmetricUnifiedFocalLoss (** input_param )
45109 result = loss (** input_data )
46110 np .testing .assert_allclose (result .detach ().cpu ().numpy (), expected_val , atol = 1e-4 , rtol = 1e-4 )
47111
@@ -52,8 +116,10 @@ def test_ill_shape(self):
52116
53117 def test_with_cuda (self ):
54118 loss = AsymmetricUnifiedFocalLoss ()
55- i = torch .tensor ([[[[1.0 , 0 ], [0 , 1.0 ]]], [[[1.0 , 0 ], [0 , 1.0 ]]]])
56- j = torch .tensor ([[[[1.0 , 0 ], [0 , 1.0 ]]], [[[1.0 , 0 ], [0 , 1.0 ]]]])
119+ i = torch .tensor (
120+ [[[[logit_pos , logit_neg ], [logit_neg , logit_pos ]]], [[[logit_pos , logit_neg ], [logit_neg , logit_pos ]]]]
121+ )
122+ j = torch .tensor ([[[[1.0 , 0.0 ], [0.0 , 1.0 ]]], [[[1.0 , 0.0 ], [0.0 , 1.0 ]]]])
57123 if torch .cuda .is_available ():
58124 i = i .cuda ()
59125 j = j .cuda ()
0 commit comments