Skip to content

Commit 21d827d

Browse files
Add MMD chunk size equivalence test
1 parent cfe72e5 commit 21d827d

File tree

1 file changed

+63
-4
lines changed
  • frouros/tests/unit/detectors/data_drift/batch/distance_based

1 file changed

+63
-4
lines changed

frouros/tests/unit/detectors/data_drift/batch/distance_based/test_mmd.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from frouros.detectors.data_drift import MMD # type: ignore
1010
from frouros.utils.kernels import rbf_kernel
1111

12+
RANDOM_SEED = 31
13+
DEFAULT_SIGMA = 0.5
14+
1215

1316
@pytest.mark.parametrize(
1417
"distribution_p, distribution_q, expected_distance",
@@ -35,12 +38,15 @@ def test_mmd_batch_univariate(
3538
:param expected_distance: expected distance value
3639
:type expected_distance: float
3740
"""
38-
np.random.seed(seed=31)
41+
np.random.seed(seed=RANDOM_SEED)
3942
X_ref = np.random.normal(*distribution_p) # noqa: N806
4043
X_test = np.random.normal(*distribution_q) # noqa: N806
4144

4245
detector = MMD(
43-
kernel=partial(rbf_kernel, sigma=0.5),
46+
kernel=partial(
47+
rbf_kernel,
48+
sigma=DEFAULT_SIGMA,
49+
),
4450
)
4551
_ = detector.fit(X=X_ref)
4652

@@ -77,11 +83,14 @@ def test_mmd_batch_precomputed_expected_k_xx(
7783
:param chunk_size: chunk size
7884
:type chunk_size: Optional[int]
7985
"""
80-
np.random.seed(seed=31)
86+
np.random.seed(seed=RANDOM_SEED)
8187
X_ref = np.random.normal(*distribution_p) # noqa: N806
8288
X_test = np.random.normal(*distribution_q) # noqa: N806
8389

84-
kernel = partial(rbf_kernel, sigma=0.5)
90+
kernel = partial(
91+
rbf_kernel,
92+
sigma=DEFAULT_SIGMA,
93+
)
8594

8695
detector = MMD(
8796
kernel=kernel,
@@ -101,3 +110,53 @@ def test_mmd_batch_precomputed_expected_k_xx(
101110
)
102111

103112
assert np.isclose(precomputed_distance, scratch_distance)
113+
114+
115+
@pytest.mark.parametrize(
116+
"distribution_p, distribution_q, chunk_size",
117+
[
118+
((0, 1, size), (2, 1, size), chunk_size)
119+
for size in [10, 100]
120+
for chunk_size in list(range(1, 11))
121+
],
122+
)
123+
def test_mmd_chunk_size_equivalence(
124+
distribution_p: Tuple[float, float, int],
125+
distribution_q: Tuple[float, float, int],
126+
chunk_size: int,
127+
) -> None:
128+
"""Test MMD with chunk_size=None vs specific chunk_size.
129+
130+
:param distribution_p: mean, std and size of samples from distribution p
131+
:type distribution_p: Tuple[float, float, int]
132+
:param distribution_q: mean, std and size of samples from distribution q
133+
:type distribution_q: Tuple[float, float, int]
134+
:param chunk_size: specific chunk size to compare with None
135+
:type chunk_size: int
136+
"""
137+
np.random.seed(seed=RANDOM_SEED)
138+
X_ref = np.random.normal(*distribution_p) # noqa: N806
139+
X_test = np.random.normal(*distribution_q) # noqa: N806
140+
141+
kernel = partial(
142+
rbf_kernel,
143+
sigma=DEFAULT_SIGMA,
144+
)
145+
146+
# Detector with chunk_size=None
147+
detector_none = MMD(
148+
kernel=kernel,
149+
chunk_size=None,
150+
)
151+
_ = detector_none.fit(X=X_ref)
152+
result_none = detector_none.compare(X=X_test)[0].distance
153+
154+
# Detector with specific chunk_size
155+
detector_chunk = MMD(
156+
kernel=kernel,
157+
chunk_size=chunk_size,
158+
)
159+
_ = detector_chunk.fit(X=X_ref)
160+
result_chunk = detector_chunk.compare(X=X_test)[0].distance
161+
162+
assert np.isclose(result_none, result_chunk)

0 commit comments

Comments
 (0)