|
107 | 107 |
|
108 | 108 | TEST_CASE_6 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [1.0000, 1.0000]] |
109 | 109 |
|
110 | | -TEST_CASE_7 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [1.0000, 1.0000]] |
| 110 | +TEST_CASE_7 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [0.0000, 0.0000]] |
| 111 | + |
| 112 | +TEST_CASE_8 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [0.0000, 0.0000]] |
| 113 | + |
| 114 | +TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]] |
111 | 115 |
|
112 | 116 |
|
113 | 117 | class TestComputeMeanDice(unittest.TestCase): |
114 | 118 | # Functional part tests |
115 | | - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7]) |
| 119 | + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) |
116 | 120 | def test_value(self, input_data, expected_value): |
117 | 121 | result = compute_generalized_dice(**input_data) |
118 | 122 | np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) |
|
0 commit comments