Skip to content

Commit 4f71e12

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #269 from IFCA/feature-precompute-mmd-ref
Add precompute kernel ref matrix values for MMD
2 parents 5f9c5ac + 3598aa2 commit 4f71e12

File tree

6 files changed

+178
-45
lines changed

6 files changed

+178
-45
lines changed

frouros/detectors/data_drift/batch/distance_based/mmd.py

Lines changed: 71 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
11
"""MMD (Maximum Mean Discrepancy) module."""
22

33
import itertools
4-
import math
54
from typing import Callable, Generator, Optional, List, Union
65

76
import numpy as np # type: ignore
8-
import tqdm # type: ignore
97

108
from frouros.callbacks.batch.base import BaseCallbackBatch
119
from frouros.detectors.data_drift.base import MultivariateData
@@ -64,6 +62,7 @@ def __init__( # noqa: D107
6462
)
6563
self.kernel = kernel
6664
self.chunk_size = chunk_size
65+
self._expected_k_xx = None
6766

6867
@property
6968
def chunk_size(self) -> Optional[int]:
@@ -122,11 +121,47 @@ def _distance_measure(
122121
Y=X,
123122
kernel=self.kernel,
124123
chunk_size=self.chunk_size,
124+
expected_k_xx=self._expected_k_xx,
125125
**kwargs,
126126
)
127127
distance_test = DistanceResult(distance=mmd)
128128
return distance_test
129129

130+
def _fit(
    self,
    X: np.ndarray,  # noqa: N803
) -> None:
    """Fit on the reference data and cache its expected kernel value.

    After delegating to the parent ``_fit``, the kernel is evaluated over
    every ordered pair of reference chunks, the diagonal contribution is
    subtracted and the normalized result is stored in
    ``self._expected_k_xx``.

    :param X: reference data
    :type X: numpy.ndarray
    """
    super()._fit(X=X)
    num_ref_samples = len(self.X_ref)  # type: ignore

    # Add a trailing axis to 1-D input, only for the kernel calculation.
    data = np.expand_dims(X, axis=1) if X.ndim == 1 else X

    # Without a configured chunk size, the sample count is used instead.
    chunk_size = self.chunk_size
    if chunk_size is None:
        chunk_size = num_ref_samples

    chunk_pairs = itertools.product(
        self._get_chunks(data=data, chunk_size=chunk_size),
        repeat=2,
    )
    kernel_total = self._compute_kernel(
        chunk_combinations=chunk_pairs,  # type: ignore
        kernel=self.kernel,
    )

    # Remove diagonal (j!=i case).
    # NOTE(review): subtracting the sample count presumably relies on
    # kernel(x, x) == 1 -- confirm for kernels other than RBF.
    off_diagonal_total = kernel_total - num_ref_samples
    # NOTE(review): this is zero for a single reference sample -- confirm
    # callers always fit on at least two samples.
    ordered_pair_count = num_ref_samples * (num_ref_samples - 1)
    self._expected_k_xx = (  # type: ignore
        off_diagonal_total / ordered_pair_count
    )
164+
130165
@staticmethod
131166
def _compute_kernel(chunk_combinations: Generator, kernel: Callable) -> float:
132167
k_sum = np.array([kernel(*chunk).sum() for chunk in chunk_combinations]).sum()
@@ -159,13 +194,37 @@ def _mmd( # pylint: disable=too-many-locals
159194
if "chunk_size" in kwargs and kwargs["chunk_size"] is not None
160195
else x_num_samples
161196
)
162-
x_chunks, x_chunks_copy = itertools.tee( # noqa: N806
163-
MMD._get_chunks(
197+
198+
# If expected_k_xx is provided, we don't need to compute it again
199+
if "expected_k_xx" in kwargs:
200+
x_chunks_copy = MMD._get_chunks( # noqa: N806
164201
data=X,
165-
chunk_size=chunk_size_x, # type: ignore
166-
),
167-
2,
168-
)
202+
chunk_size=chunk_size_x,
203+
)
204+
expected_k_xx = kwargs["expected_k_xx"]
205+
else:
206+
# Compute expected_k_xx
207+
x_chunks, x_chunks_copy = itertools.tee( # type: ignore
208+
MMD._get_chunks(
209+
data=X,
210+
chunk_size=chunk_size_x,
211+
),
212+
2,
213+
)
214+
x_chunks_combinations = itertools.product( # type: ignore
215+
x_chunks,
216+
repeat=2,
217+
)
218+
k_xx_sum = (
219+
MMD._compute_kernel(
220+
chunk_combinations=x_chunks_combinations, # type: ignore
221+
kernel=kernel,
222+
)
223+
# Remove diagonal (j!=i case)
224+
- x_num_samples
225+
)
226+
expected_k_xx = k_xx_sum / (x_num_samples * (x_num_samples - 1))
227+
169228
y_num_samples = len(Y) # noqa: N806
170229
chunk_size_y = (
171230
kwargs["chunk_size"]
@@ -175,14 +234,10 @@ def _mmd( # pylint: disable=too-many-locals
175234
y_chunks, y_chunks_copy = itertools.tee( # noqa: N806
176235
MMD._get_chunks(
177236
data=Y,
178-
chunk_size=chunk_size_y, # type: ignore
237+
chunk_size=chunk_size_y,
179238
),
180239
2,
181240
)
182-
x_chunks_combinations = itertools.product( # noqa: N806
183-
x_chunks,
184-
repeat=2,
185-
)
186241
y_chunks_combinations = itertools.product( # noqa: N806
187242
y_chunks,
188243
repeat=2,
@@ -192,50 +247,21 @@ def _mmd( # pylint: disable=too-many-locals
192247
y_chunks_copy,
193248
)
194249

195-
if kwargs.get("verbose", False):
196-
num_chunks_x = math.ceil(x_num_samples / chunk_size_x) # type: ignore
197-
num_chunks_y = math.ceil(y_num_samples / chunk_size_y) # type: ignore
198-
num_chunks_x_combinations = num_chunks_x**2
199-
num_chunks_y_combinations = num_chunks_y**2
200-
num_chunks_xy = (
201-
math.ceil(len(X) / chunk_size_x) * num_chunks_y # type: ignore
202-
)
203-
x_chunks_combinations = tqdm.tqdm(
204-
x_chunks_combinations,
205-
total=num_chunks_x_combinations,
206-
)
207-
y_chunks_combinations = tqdm.tqdm(
208-
y_chunks_combinations,
209-
total=num_chunks_y_combinations,
210-
)
211-
xy_chunks_combinations = tqdm.tqdm(
212-
xy_chunks_combinations,
213-
total=num_chunks_xy,
214-
)
215-
216-
k_xx_sum = (
217-
MMD._compute_kernel(
218-
chunk_combinations=x_chunks_combinations, # type: ignore
219-
kernel=kernel,
220-
)
221-
# Remove diagonal (j!=i case)
222-
- x_num_samples # type: ignore
223-
)
224250
k_yy_sum = (
225251
MMD._compute_kernel(
226252
chunk_combinations=y_chunks_combinations, # type: ignore
227253
kernel=kernel,
228254
)
229255
# Remove diagonal (j!=i case)
230-
- y_num_samples # type: ignore
256+
- y_num_samples
231257
)
232258
k_xy_sum = MMD._compute_kernel(
233259
chunk_combinations=xy_chunks_combinations, # type: ignore
234260
kernel=kernel,
235261
)
236262
mmd = (
237-
+k_xx_sum / (x_num_samples * (x_num_samples - 1))
263+
+expected_k_xx
238264
+ k_yy_sum / (y_num_samples * (y_num_samples - 1))
239-
- 2 * k_xy_sum / (x_num_samples * y_num_samples) # type: ignore
265+
- 2 * k_xy_sum / (x_num_samples * y_num_samples)
240266
)
241267
return mmd
Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
"""Detectors test init."""
Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
"""Data drift detectors test init."""
Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
"""Batch data drift detectors test init."""
Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
"""Distance based batch data drift detectors test init."""
Lines changed: 103 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,103 @@
1+
"""Test MMD."""
2+
3+
from functools import partial
4+
from typing import Optional, Tuple
5+
6+
import numpy as np # type: ignore
7+
import pytest # type: ignore
8+
9+
from frouros.detectors.data_drift import MMD
10+
from frouros.utils.kernels import rbf_kernel
11+
12+
13+
@pytest.mark.parametrize(
    "distribution_p, distribution_q, expected_distance",
    [
        # Each distribution is given as (mean, std, size).
        ((0, 1, 100), (0, 1, 100), 0.00052755),
        ((0, 1, 100), (0, 1, 10), -0.03200193),
        ((0, 1, 10), (0, 1, 100), 0.07154671),
        ((2, 1, 100), (0, 1, 100), 0.43377622),
        ((2, 1, 100), (0, 1, 10), 0.23051378),
        ((2, 1, 10), (0, 1, 100), 0.62530767),
    ],
)
def test_mmd_batch_univariate(
    distribution_p: Tuple[float, float, int],
    distribution_q: Tuple[float, float, int],
    expected_distance: float,
) -> None:
    """Check the MMD distance on univariate normal samples.

    :param distribution_p: mean, std and size of samples from distribution p
    :type distribution_p: Tuple[float, float, int]
    :param distribution_q: mean, std and size of samples from distribution q
    :type distribution_q: Tuple[float, float, int]
    :param expected_distance: expected distance value
    :type expected_distance: float
    """
    # The expected values depend on this seed and on drawing the
    # reference sample before the test sample.
    np.random.seed(seed=31)
    mean_p, std_p, size_p = distribution_p
    mean_q, std_q, size_q = distribution_q
    reference = np.random.normal(mean_p, std_p, size_p)
    sample = np.random.normal(mean_q, std_q, size_q)

    rbf = partial(rbf_kernel, sigma=0.5)
    detector = MMD(kernel=rbf)
    detector.fit(X=reference)

    distance = detector.compare(X=sample)[0].distance

    assert np.isclose(distance, expected_distance)
50+
51+
52+
@pytest.mark.parametrize(
    "distribution_p, distribution_q, chunk_size",
    [
        # Each distribution is given as (mean, std, size).
        ((0, 1, 100), (0, 1, 100), None),
        ((0, 1, 100), (0, 1, 100), 2),
        ((0, 1, 100), (0, 1, 100), 10),
        ((0, 1, 100), (0, 1, 10), None),
        ((0, 1, 100), (0, 1, 10), 2),
        ((0, 1, 100), (0, 1, 10), 10),
        ((0, 1, 10), (0, 1, 100), None),
        ((0, 1, 10), (0, 1, 100), 2),
        ((0, 1, 10), (0, 1, 100), 10),
    ],
)
def test_mmd_batch_precomputed_expected_k_xx(
    distribution_p: Tuple[float, float, int],
    distribution_q: Tuple[float, float, int],
    chunk_size: Optional[int],
) -> None:
    """Check that the fitted detector matches a from-scratch MMD.

    The distance returned by a fitted detector is compared with a direct
    call to ``MMD._mmd`` that receives no precomputed ``expected_k_xx``.

    :param distribution_p: mean, std and size of samples from distribution p
    :type distribution_p: Tuple[float, float, int]
    :param distribution_q: mean, std and size of samples from distribution q
    :type distribution_q: Tuple[float, float, int]
    :param chunk_size: chunk size
    :type chunk_size: Optional[int]
    """
    # Fixed seed; the reference sample is drawn before the test sample.
    np.random.seed(seed=31)
    mean_p, std_p, size_p = distribution_p
    mean_q, std_q, size_q = distribution_q
    reference = np.random.normal(mean_p, std_p, size_p)
    sample = np.random.normal(mean_q, std_q, size_q)

    rbf = partial(rbf_kernel, sigma=0.5)

    detector = MMD(kernel=rbf, chunk_size=chunk_size)
    detector.fit(X=reference)

    # Distance obtained through the fitted detector.
    from_detector = detector.compare(X=sample)[0].distance

    # Distance computed from scratch, with no precomputed term passed in.
    from_scratch = MMD._mmd(  # pylint: disable=protected-access
        X=reference,
        Y=sample,
        kernel=rbf,
        chunk_size=chunk_size,
    )

    assert np.isclose(from_detector, from_scratch)

0 commit comments

Comments
 (0)