2020
2121
2222class TestAUCMLoss (unittest .TestCase ):
23+ """Test cases for AUCMLoss."""
24+
2325 def test_v1 (self ):
26+ """Test AUCMLoss with version 'v1'."""
2427 loss_fn = AUCMLoss (version = "v1" )
2528 input = torch .randn (32 , 1 , requires_grad = True )
2629 target = torch .randint (0 , 2 , (32 , 1 )).float ()
@@ -29,6 +32,7 @@ def test_v1(self):
2932 self .assertEqual (loss .ndim , 0 )
3033
3134 def test_v2 (self ):
35+ """Test AUCMLoss with version 'v2'."""
3236 loss_fn = AUCMLoss (version = "v2" )
3337 input = torch .randn (32 , 1 , requires_grad = True )
3438 target = torch .randint (0 , 2 , (32 , 1 )).float ()
@@ -37,31 +41,36 @@ def test_v2(self):
3741 self .assertEqual (loss .ndim , 0 )
3842
3943 def test_invalid_version (self ):
44+ """Test that invalid version raises ValueError."""
4045 with self .assertRaises (ValueError ):
4146 AUCMLoss (version = "invalid" )
4247
4348 def test_invalid_input_shape (self ):
49+ """Test that invalid input shape raises ValueError."""
4450 loss_fn = AUCMLoss ()
4551 input = torch .randn (32 , 2 ) # Wrong channel
4652 target = torch .randint (0 , 2 , (32 , 1 )).float ()
4753 with self .assertRaises (ValueError ):
4854 loss_fn (input , target )
4955
5056 def test_invalid_target_shape (self ):
57+ """Test that invalid target shape raises ValueError."""
5158 loss_fn = AUCMLoss ()
5259 input = torch .randn (32 , 1 )
5360 target = torch .randint (0 , 2 , (32 , 2 )).float () # Wrong channel
5461 with self .assertRaises (ValueError ):
5562 loss_fn (input , target )
5663
5764 def test_shape_mismatch (self ):
65+ """Test that mismatched shapes raise ValueError."""
5866 loss_fn = AUCMLoss ()
5967 input = torch .randn (32 , 1 )
6068 target = torch .randint (0 , 2 , (16 , 1 )).float ()
6169 with self .assertRaises (ValueError ):
6270 loss_fn (input , target )
6371
6472 def test_backward (self ):
73+ """Test that gradients can be computed."""
6574 loss_fn = AUCMLoss ()
6675 input = torch .randn (32 , 1 , requires_grad = True )
6776 target = torch .randint (0 , 2 , (32 , 1 )).float ()
@@ -70,6 +79,7 @@ def test_backward(self):
7079 self .assertIsNotNone (input .grad )
7180
7281 def test_script_save (self ):
82+ """Test that the loss can be saved as TorchScript."""
7383 loss_fn = AUCMLoss ()
7484 test_script_save (loss_fn , torch .randn (32 , 1 ), torch .randint (0 , 2 , (32 , 1 )).float ())
7585
0 commit comments