|
1 | 1 | """Test MMD.""" |
2 | 2 |
|
3 | 3 | from functools import partial |
4 | | -from typing import Optional, Tuple |
| 4 | +from typing import Any, Optional, Tuple |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import pytest |
@@ -160,3 +160,63 @@ def test_mmd_chunk_size_equivalence( |
160 | 160 | result_chunk = detector_chunk.compare(X=X_test)[0].distance |
161 | 161 |
|
162 | 162 | assert np.isclose(result_none, result_chunk) |
| 163 | + |
| 164 | + |
| 165 | +@pytest.mark.parametrize( |
| 166 | + "chunk_size", |
| 167 | + [ |
| 168 | + None, |
| 169 | + 1, |
| 170 | + 2, |
| 171 | + ], |
| 172 | +) |
| 173 | +def test_mmd_chunk_size_initialization_valid( |
| 174 | + chunk_size: Optional[int], |
| 175 | +) -> None: |
| 176 | + """Test MMD initialization with valid chunk sizes. |
| 177 | +
|
| 178 | + :param chunk_size: chunk size to test |
| 179 | + :type chunk_size: Optional[int] |
| 180 | + """ |
| 181 | + np.random.seed(seed=RANDOM_SEED) |
| 182 | + X_ref = np.random.normal(0, 1, 100) |
| 183 | + X_test = np.random.normal(0, 1, 100) |
| 184 | + |
| 185 | + kernel = partial(rbf_kernel, sigma=DEFAULT_SIGMA) |
| 186 | + |
| 187 | + detector = MMD( |
| 188 | + kernel=kernel, |
| 189 | + chunk_size=chunk_size, |
| 190 | + ) |
| 191 | + _ = detector.fit(X=X_ref) |
| 192 | + result = detector.compare(X=X_test)[0] |
| 193 | + |
| 194 | + assert result is not None |
| 195 | + |
| 196 | + |
| 197 | +@pytest.mark.parametrize( |
| 198 | + "chunk_size", |
| 199 | + [ |
| 200 | + 0, |
| 201 | + -1, |
| 202 | + "invalid", |
| 203 | + 1.5, |
| 204 | + [1, 2], |
| 205 | + {1: 2}, |
| 206 | + ], |
| 207 | +) |
| 208 | +def test_mmd_chunk_size_invalid( |
| 209 | + chunk_size: Any, |
| 210 | +) -> None: |
| 211 | + """Test MMD initialization with invalid chunk sizes. |
| 212 | +
|
| 213 | + :param chunk_size: chunk size to test |
| 214 | + :type chunk_size: Any |
| 215 | + """ |
| 216 | + kernel = partial(rbf_kernel, sigma=0.5) |
| 217 | + |
| 218 | + with pytest.raises((TypeError, ValueError)): |
| 219 | + MMD( |
| 220 | + kernel=kernel, |
| 221 | + chunk_size=chunk_size, |
| 222 | + ) |
0 commit comments