Skip to content

Commit 03b5fe7

Browse files
Add unit test MMD univariate data
1 parent 5f9c5ac commit 03b5fe7

File tree

5 files changed

+53
-0
lines changed

5 files changed

+53
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Detectors test init."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Data drift detectors test init."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Batch data drift detectors test init."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Distance based batch data drift detectors test init."""
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Test MMD."""
2+
3+
from functools import partial
4+
from typing import Tuple
5+
6+
import numpy as np # type: ignore
7+
import pytest # type: ignore
8+
9+
from frouros.detectors.data_drift import MMD
10+
from frouros.utils.kernels import rbf_kernel
11+
12+
13+
@pytest.mark.parametrize(
14+
"distribution_p, distribution_q, expected_distance",
15+
[
16+
((0, 1, 100), (0, 1, 100), 0.00052755), # (mean, std, size)
17+
((0, 1, 100), (0, 1, 10), -0.03200193),
18+
((0, 1, 10), (0, 1, 100), 0.07154671),
19+
((2, 1, 100), (0, 1, 100), 0.43377622),
20+
((2, 1, 100), (0, 1, 10), 0.23051378),
21+
((2, 1, 10), (0, 1, 100), 0.62530767),
22+
],
23+
)
24+
def test_mmd_batch_univariate(
25+
distribution_p: Tuple[float, float, int],
26+
distribution_q: Tuple[float, float, int],
27+
expected_distance: float,
28+
) -> None:
29+
"""Test MMD batch with univariate data.
30+
31+
:param distribution_p: mean, std and size of samples from distribution p
32+
:type distribution_p: Tuple[float, float, int]
33+
:param distribution_q: mean, std and size of samples from distribution q
34+
:type distribution_q: Tuple[float, float, int]
35+
:param expected_distance: expected distance value
36+
:type expected_distance: float
37+
"""
38+
np.random.seed(seed=31)
39+
X_ref = np.random.normal(*distribution_p) # noqa: N806
40+
X_test = np.random.normal(*distribution_q) # noqa: N806
41+
42+
detector = MMD(
43+
kernel=partial(rbf_kernel, sigma=0.5),
44+
)
45+
_ = detector.fit(X=X_ref)
46+
47+
result = detector.compare(X=X_test)[0]
48+
49+
assert np.isclose(result.distance, expected_distance)

0 commit comments

Comments
 (0)