2222
2323class TestFocalLoss (unittest .TestCase ):
2424 def test_consistency_with_cross_entropy_2d (self ):
25- # For gamma=0 the focal loss reduces to the cross entropy loss
26- focal_loss = FocalLoss (to_onehot_y = True , gamma = 0.0 , reduction = "mean" , weight = 1.0 )
27- ce = nn .CrossEntropyLoss (reduction = "mean" )
25+ """ For gamma=0 the focal loss reduces to the cross entropy loss"""
26+ focal_loss = FocalLoss (to_onehot_y = False , gamma = 0.0 , reduction = "mean" , weight = 1.0 )
27+ ce = nn .BCEWithLogitsLoss (reduction = "mean" )
2828 max_error = 0
2929 class_num = 10
3030 batch_size = 128
3131 for _ in range (100 ):
3232 # Create a random tensor of shape (batch_size, class_num, 8, 4)
3333 x = torch .rand (batch_size , class_num , 8 , 4 , requires_grad = True )
3434 # Create a random batch of classes
35- l = torch .randint (low = 0 , high = class_num , size = (batch_size , 1 , 8 , 4 ))
35+ l = torch .randint (low = 0 , high = 2 , size = (batch_size , class_num , 8 , 4 )). float ( )
3636 if torch .cuda .is_available ():
3737 x = x .cuda ()
3838 l = l .cuda ()
3939 output0 = focal_loss (x , l )
40- output1 = ce (x , l [:, 0 ]) / class_num
40+ output1 = ce (x , l )
4141 a = float (output0 .cpu ().detach ())
4242 b = float (output1 .cpu ().detach ())
4343 if abs (a - b ) > max_error :
4444 max_error = abs (a - b )
4545 self .assertAlmostEqual (max_error , 0.0 , places = 3 )
4646
4747 def test_consistency_with_cross_entropy_2d_onehot_label (self ):
48- # For gamma=0 the focal loss reduces to the cross entropy loss
49- focal_loss = FocalLoss (to_onehot_y = False , gamma = 0.0 , reduction = "mean" )
50- ce = nn .CrossEntropyLoss (reduction = "mean" )
48+ """ For gamma=0 the focal loss reduces to the cross entropy loss"""
49+ focal_loss = FocalLoss (to_onehot_y = True , gamma = 0.0 , reduction = "mean" )
50+ ce = nn .BCEWithLogitsLoss (reduction = "mean" )
5151 max_error = 0
5252 class_num = 10
5353 batch_size = 128
@@ -59,18 +59,18 @@ def test_consistency_with_cross_entropy_2d_onehot_label(self):
5959 if torch .cuda .is_available ():
6060 x = x .cuda ()
6161 l = l .cuda ()
62- output0 = focal_loss (x , one_hot ( l , num_classes = class_num ) )
63- output1 = ce (x , l [:, 0 ]) / class_num
62+ output0 = focal_loss (x , l )
63+ output1 = ce (x , one_hot ( l , num_classes = class_num ))
6464 a = float (output0 .cpu ().detach ())
6565 b = float (output1 .cpu ().detach ())
6666 if abs (a - b ) > max_error :
6767 max_error = abs (a - b )
6868 self .assertAlmostEqual (max_error , 0.0 , places = 3 )
6969
7070 def test_consistency_with_cross_entropy_classification (self ):
71- # for gamma=0 the focal loss reduces to the cross entropy loss
71+ """ for gamma=0 the focal loss reduces to the cross entropy loss"""
7272 focal_loss = FocalLoss (to_onehot_y = True , gamma = 0.0 , reduction = "mean" )
73- ce = nn .CrossEntropyLoss (reduction = "mean" )
73+ ce = nn .BCEWithLogitsLoss (reduction = "mean" )
7474 max_error = 0
7575 class_num = 10
7676 batch_size = 128
@@ -84,19 +84,43 @@ def test_consistency_with_cross_entropy_classification(self):
8484 x = x .cuda ()
8585 l = l .cuda ()
8686 output0 = focal_loss (x , l )
87- output1 = ce (x , l [:, 0 ]) / class_num
87+ output1 = ce (x , one_hot ( l , num_classes = class_num ))
8888 a = float (output0 .cpu ().detach ())
8989 b = float (output1 .cpu ().detach ())
9090 if abs (a - b ) > max_error :
9191 max_error = abs (a - b )
9292 self .assertAlmostEqual (max_error , 0.0 , places = 3 )
9393
94+ def test_consistency_with_cross_entropy_classification_01 (self ):
95+ # for gamma=0.1 the focal loss differs from the cross entropy loss
96+ focal_loss = FocalLoss (to_onehot_y = True , gamma = 0.1 , reduction = "mean" )
97+ ce = nn .BCEWithLogitsLoss (reduction = "mean" )
98+ max_error = 0
99+ class_num = 10
100+ batch_size = 128
101+ for _ in range (100 ):
102+ # Create a random scores tensor of shape (batch_size, class_num)
103+ x = torch .rand (batch_size , class_num , requires_grad = True )
104+ # Create a random batch of classes
105+ l = torch .randint (low = 0 , high = class_num , size = (batch_size , 1 ))
106+ l = l .long ()
107+ if torch .cuda .is_available ():
108+ x = x .cuda ()
109+ l = l .cuda ()
110+ output0 = focal_loss (x , l )
111+ output1 = ce (x , one_hot (l , num_classes = class_num ))
112+ a = float (output0 .cpu ().detach ())
113+ b = float (output1 .cpu ().detach ())
114+ if abs (a - b ) > max_error :
115+ max_error = abs (a - b )
116+ self .assertNotAlmostEqual (max_error , 0.0 , places = 3 )
117+
94118 def test_bin_seg_2d (self ):
95119 # define 2d examples
96120 target = torch .tensor ([[0 , 0 , 0 , 0 ], [0 , 1 , 1 , 0 ], [0 , 1 , 1 , 0 ], [0 , 0 , 0 , 0 ]])
97121 # add another dimension corresponding to the batch (batch size = 1 here)
98122 target = target .unsqueeze (0 ) # shape (1, H, W)
99- pred_very_good = 1000 * F .one_hot (target , num_classes = 2 ).permute (0 , 3 , 1 , 2 ).float ()
123+ pred_very_good = 100 * F .one_hot (target , num_classes = 2 ).permute (0 , 3 , 1 , 2 ).float () - 50.0
100124
101125 # initialize the mean dice loss
102126 loss = FocalLoss (to_onehot_y = True )
@@ -112,7 +136,7 @@ def test_empty_class_2d(self):
112136 target = torch .tensor ([[0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 ]])
113137 # add another dimension corresponding to the batch (batch size = 1 here)
114138 target = target .unsqueeze (0 ) # shape (1, H, W)
115- pred_very_good = 1000 * F .one_hot (target , num_classes = num_classes ).permute (0 , 3 , 1 , 2 ).float ()
139+ pred_very_good = 1000 * F .one_hot (target , num_classes = num_classes ).permute (0 , 3 , 1 , 2 ).float () - 500.0
116140
117141 # initialize the mean dice loss
118142 loss = FocalLoss (to_onehot_y = True )
@@ -128,7 +152,7 @@ def test_multi_class_seg_2d(self):
128152 target = torch .tensor ([[0 , 0 , 0 , 0 ], [0 , 1 , 2 , 0 ], [0 , 3 , 4 , 0 ], [0 , 0 , 0 , 0 ]])
129153 # add another dimension corresponding to the batch (batch size = 1 here)
130154 target = target .unsqueeze (0 ) # shape (1, H, W)
131- pred_very_good = 1000 * F .one_hot (target , num_classes = num_classes ).permute (0 , 3 , 1 , 2 ).float ()
155+ pred_very_good = 1000 * F .one_hot (target , num_classes = num_classes ).permute (0 , 3 , 1 , 2 ).float () - 500.0
132156 # initialize the mean dice loss
133157 loss = FocalLoss (to_onehot_y = True )
134158 loss_onehot = FocalLoss (to_onehot_y = False )
@@ -159,7 +183,7 @@ def test_bin_seg_3d(self):
159183 # add another dimension corresponding to the batch (batch size = 1 here)
160184 target = target .unsqueeze (0 ) # shape (1, H, W, D)
161185 target_one_hot = F .one_hot (target , num_classes = num_classes ).permute (0 , 4 , 1 , 2 , 3 ) # test one hot
162- pred_very_good = 1000 * F .one_hot (target , num_classes = num_classes ).permute (0 , 4 , 1 , 2 , 3 ).float ()
186+ pred_very_good = 1000 * F .one_hot (target , num_classes = num_classes ).permute (0 , 4 , 1 , 2 , 3 ).float () - 500.0
163187
164188 # initialize the mean dice loss
165189 loss = FocalLoss (to_onehot_y = True )
@@ -173,6 +197,19 @@ def test_bin_seg_3d(self):
173197 focal_loss_good = float (loss_onehot (pred_very_good , target_one_hot ).cpu ())
174198 self .assertAlmostEqual (focal_loss_good , 0.0 , places = 3 )
175199
200+ def test_foreground (self ):
201+ background = torch .ones (1 , 1 , 5 , 5 )
202+ foreground = torch .zeros (1 , 1 , 5 , 5 )
203+ target = torch .cat ((background , foreground ), dim = 1 )
204+ input = torch .cat ((background , foreground ), dim = 1 )
205+ target [:, 0 , 2 , 2 ] = 0
206+ target [:, 1 , 2 , 2 ] = 1
207+
208+ fgbg = FocalLoss (to_onehot_y = False , include_background = True )(input , target )
209+ fg = FocalLoss (to_onehot_y = False , include_background = False )(input , target )
210+ self .assertAlmostEqual (float (fgbg .cpu ()), 0.1116 , places = 3 )
211+ self .assertAlmostEqual (float (fg .cpu ()), 0.1733 , places = 3 )
212+
176213 def test_ill_opts (self ):
177214 chn_input = torch .ones ((1 , 2 , 3 ))
178215 chn_target = torch .ones ((1 , 2 , 3 ))
@@ -182,7 +219,7 @@ def test_ill_opts(self):
182219 def test_ill_shape (self ):
183220 chn_input = torch .ones ((1 , 2 , 3 ))
184221 chn_target = torch .ones ((1 , 3 ))
185- with self .assertRaisesRegex (AssertionError , "" ):
222+ with self .assertRaisesRegex (ValueError , "" ):
186223 FocalLoss (reduction = "mean" )(chn_input , chn_target )
187224
188225 def test_ill_class_weight (self ):
0 commit comments