@@ -374,6 +374,47 @@ def test_script(self):
         test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)
 
+    def test_alpha_sequence_broadcasting(self):
+        """
+        Test FocalLoss with alpha given as a per-class sequence, verifying that
+        the weights broadcast over the batch and spatial dimensions.
+        """
+        num_classes = 3
+        alpha_seq = [0.1, 0.5, 2.0]
+        batch_size = 2
+        spatial_dims = (4, 4)
+
+        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+        for device in devices:
+            logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
+            target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
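+            # target carries class indices in a singleton channel dim; with
+            # to_onehot_y=True the loss expands it to num_classes one-hot channels.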
+
+            # Case 1: Softmax + Alpha Sequence
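+            # use_softmax=True normalizes logits across the class channel
+            # (mutually exclusive classes) before applying the focal term.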
+            loss_func_softmax = FocalLoss(
+                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=True, reduction="mean"
+            )
+            loss_soft = loss_func_softmax(logits, target)
+
+            self.assertTrue(torch.is_tensor(loss_soft))
+            self.assertEqual(loss_soft.ndim, 0)
+            self.assertTrue(loss_soft > 0, f"Softmax loss on {device} should be positive")
+
+            # Case 2: Sigmoid + Alpha Sequence
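+            # use_softmax=False applies a per-channel sigmoid instead, treating each
+            # class as an independent binary task; alpha broadcasts the same way.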
+            loss_func_sigmoid = FocalLoss(
+                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=False, reduction="mean"
+            )
+            loss_sig = loss_func_sigmoid(logits, target)
+
+            self.assertTrue(torch.is_tensor(loss_sig))
+            self.assertEqual(loss_sig.ndim, 0)
+            self.assertTrue(loss_sig > 0, f"Sigmoid loss on {device} should be positive")
+
+            # Case 3: Error Handling (Mismatched alpha length)
+            if device == devices[0]:
+                wrong_alpha = [0.1, 0.5]
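+                # two weights for three classes should raise a ValueError that
+                # mentions the alpha length; one device suffices for this check.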
+                with self.assertRaisesRegex(ValueError, "length of alpha"):
+                    FocalLoss(to_onehot_y=True, alpha=wrong_alpha, use_softmax=True)(logits, target)
+
 
 if __name__ == "__main__":
     unittest.main()