Skip to content

Commit f137843

Browse files
Add unit test MMD precomputed
1 parent d2ef3e6 commit f137843

File tree

1 file changed

+56
-1
lines changed
  • frouros/tests/unit/detectors/data_drift/batch/distance_based

1 file changed

+56
-1
lines changed

frouros/tests/unit/detectors/data_drift/batch/distance_based/test_mmd.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Test MMD."""
22

33
from functools import partial
4-
from typing import Tuple
4+
from typing import Optional, Tuple
55

66
import numpy as np # type: ignore
77
import pytest # type: ignore
@@ -47,3 +47,58 @@ def test_mmd_batch_univariate(
4747
result = detector.compare(X=X_test)[0]
4848

4949
assert np.isclose(result.distance, expected_distance)
50+
51+
52+
@pytest.mark.parametrize(
53+
"distribution_p, distribution_q, chunk_size",
54+
[
55+
((0, 1, 100), (0, 1, 100), None), # (mean, std, size)
56+
((0, 1, 100), (0, 1, 100), 2),
57+
((0, 1, 100), (0, 1, 100), 10),
58+
((0, 1, 100), (0, 1, 10), None),
59+
((0, 1, 100), (0, 1, 10), 2),
60+
((0, 1, 100), (0, 1, 10), 10),
61+
((0, 1, 10), (0, 1, 100), None),
62+
((0, 1, 10), (0, 1, 100), 2),
63+
((0, 1, 10), (0, 1, 100), 10),
64+
],
65+
)
66+
def test_mmd_batch_precomputed_expected_k_xx(
67+
distribution_p: Tuple[float, float, int],
68+
distribution_q: Tuple[float, float, int],
69+
chunk_size: Optional[int],
70+
) -> None:
71+
"""Test MMD batch with precomputed expected k_xx.
72+
73+
:param distribution_p: mean, std and size of samples from distribution p
74+
:type distribution_p: Tuple[float, float, int]
75+
:param distribution_q: mean, std and size of samples from distribution q
76+
:type distribution_q: Tuple[float, float, int]
77+
:param chunk_size: chunk size
78+
:type chunk_size: Optional[int]
79+
"""
80+
np.random.seed(seed=31)
81+
X_ref = np.random.normal(*distribution_p) # noqa: N806
82+
X_test = np.random.normal(*distribution_q) # noqa: N806
83+
84+
kernel = partial(rbf_kernel, sigma=0.5)
85+
86+
detector = MMD(
87+
kernel=kernel,
88+
chunk_size=chunk_size,
89+
)
90+
_ = detector.fit(X=X_ref)
91+
92+
# Computes mmd using precomputed expected k_xx
93+
precomputed_distance = detector.compare(X=X_test)[0].distance
94+
95+
# Computes mmd from scratch
96+
scratch_distance = MMD._mmd(
97+
X=X_ref,
98+
Y=X_test,
99+
kernel=kernel,
100+
chunk_size=chunk_size,
101+
)
102+
103+
assert np.isclose(precomputed_distance, scratch_distance)
104+

0 commit comments

Comments
 (0)