 from monai.losses import AsymmetricUnifiedFocalLoss

 # 1. Binary Case (Logits input): Prediction matches GT perfectly
-# Input Shape: (B, 1, H, W) -> Auto expanded internally
 TEST_CASE_BINARY_LOGITS = [
     {"y_pred": torch.tensor([[[[10.0, -10.0], [-10.0, 10.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
     0.0,
-    {"use_softmax": False, "to_onehot_y": False},
+    {"use_softmax": False, "to_onehot_y": False, "num_classes": 2},
 ]

 # 2. Binary Case (2 Channels input): Prediction matches GT perfectly
-# Input Shape: (B, 2, H, W)
 TEST_CASE_BINARY_2CH = [
     {
-        "y_pred": torch.tensor(
-            [[[[-10.0, 10.0], [10.0, -10.0]], [[10.0, -10.0], [-10.0, 10.0]]]]  # Ch0 (Background): Low, High, High, Low
-        ),  # Ch1 (Foreground): High, Low, Low, High
+        "y_pred": torch.tensor([[[[-10.0, 10.0], [10.0, -10.0]], [[10.0, -10.0], [-10.0, 10.0]]]]),
         "y_true": torch.tensor([[[[1, 0], [0, 1]]]]),
     },
     0.0,
-    {"use_softmax": True, "to_onehot_y": True},
+    {"use_softmax": True, "to_onehot_y": True, "num_classes": 2},
 ]

 # 3. Multi-Class Case (3 Channels): Prediction matches GT perfectly
 TEST_CASE_MULTICLASS = [
     {
         "y_pred": torch.tensor(
             [
                 [
-                    [[10.0, -10.0], [-10.0, 10.0]],  # Class 0 Logits
-                    [[-10.0, 10.0], [-10.0, -10.0]],  # Class 1 Logits
-                    [[-10.0, -10.0], [10.0, -10.0]],
+                    [[10.0, -10.0], [-10.0, 10.0]],  # Class 0
+                    [[-10.0, 10.0], [-10.0, -10.0]],  # Class 1
+                    [[-10.0, -10.0], [10.0, -10.0]],  # Class 2
                 ]
             ]
-        ),  # Class 2 Logits
-        "y_true": torch.tensor([[[[0, 1], [2, 0]]]]),  # Indices
+        ),
+        "y_true": torch.tensor([[[[0, 1], [2, 0]]]]),
     },
     0.0,
-    {"use_softmax": True, "to_onehot_y": True},
+    {"use_softmax": True, "to_onehot_y": True, "num_classes": 3},
 ]

 # 4. Multi-Class Case: Wrong Prediction
 TEST_CASE_MULTICLASS_WRONG = [
     {
         "y_pred": torch.tensor(
             [[[[-10.0, -10.0], [-10.0, -10.0]], [[10.0, 10.0], [10.0, 10.0]], [[-10.0, -10.0], [-10.0, -10.0]]]]
         ),
-        "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),  # GT is class 0, but Pred is class 1
+        "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),
     },
     None,
-    {"use_softmax": True, "to_onehot_y": True},
+    {"use_softmax": True, "to_onehot_y": True, "num_classes": 3},
 ]


@@ -77,11 +73,11 @@ class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):
     def test_perfect_prediction(self, input_data, expected_val, args):
         loss_func = AsymmetricUnifiedFocalLoss(**args)
         result = loss_func(**input_data)
-        # We use a small tolerance because 10.0 logits is not exactly probability 1.0
+        # Using a relaxed tolerance for logits -> probability conversion
         np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-3, rtol=1e-3)

     @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
-    def test_wrong_prediction(self, input_data, expected_val, args):
+    def test_wrong_prediction(self, input_data, _, args):
         loss_func = AsymmetricUnifiedFocalLoss(**args)
         result = loss_func(**input_data)
         self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")
@@ -93,7 +89,6 @@ def test_ill_shape(self):

     def test_with_cuda(self):
         if not torch.cuda.is_available():
-            print("CUDA not available, skipping test_with_cuda")
             return

         loss = AsymmetricUnifiedFocalLoss(use_softmax=False, to_onehot_y=False)
@@ -102,7 +97,6 @@ def test_with_cuda(self):
         j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]).cuda()

         output = loss(i, j)
-        print(f"CUDA Output: {output.item()}")
         self.assertTrue(output.is_cuda)
         self.assertLess(output.item(), 1.0)
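
For reference, a minimal standalone sketch of the multi-class usage these tests exercise. The use_softmax, to_onehot_y, and num_classes constructor arguments are assumed from the parameterized cases above (i.e. the signature this PR targets); released MONAI versions of AsymmetricUnifiedFocalLoss may differ.

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Assumed signature, mirroring TEST_CASE_MULTICLASS above (not necessarily the released API).
loss_func = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True, num_classes=3)

# Logits of shape (B=1, C=3, H=2, W=2); magnitudes of +/-10.0 make softmax effectively one-hot.
y_pred = torch.tensor(
    [
        [
            [[10.0, -10.0], [-10.0, 10.0]],   # class 0
            [[-10.0, 10.0], [-10.0, -10.0]],  # class 1
            [[-10.0, -10.0], [10.0, -10.0]],  # class 2
        ]
    ]
)
# Class-index labels of shape (1, 1, 2, 2); to_onehot_y expands them to one-hot internally.
y_true = torch.tensor([[[[0, 1], [2, 0]]]])

loss = loss_func(y_pred, y_true)
print(loss.item())  # near 0.0 for a confident, correct prediction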