Skip to content

Commit 3279c10

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #157 from IFCA/feature-mmd-chunk-pairwise-distance
Allow chunk pairwise distance in MMD
2 parents afbc599 + 88e89c7 commit 3279c10

File tree

8 files changed

+225
-23
lines changed

8 files changed

+225
-23
lines changed

frouros/detectors/data_drift/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def X_ref(self, value: Optional[np.ndarray]) -> None: # noqa: N802
175175
self._check_array(X=value)
176176
self._X_ref = value
177177

178-
def fit(self, X: np.ndarray) -> Dict[str, Any]: # noqa: N803
178+
def fit(self, X: np.ndarray, **kwargs) -> Dict[str, Any]: # noqa: N803
179179
"""Fit detector.
180180
181181
:param X: feature data
@@ -186,7 +186,7 @@ def fit(self, X: np.ndarray) -> Dict[str, Any]: # noqa: N803
186186
self._check_fit_dimensions(X=X)
187187
for callback in self.callbacks: # type: ignore
188188
callback.on_fit_start()
189-
self._fit(X=X)
189+
self._fit(X=X, **kwargs)
190190
for callback in self.callbacks: # type: ignore
191191
callback.on_fit_end()
192192

frouros/detectors/data_drift/batch/distance_based/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def _distance_measure(
108108
self,
109109
X_ref: np.ndarray, # noqa: N803
110110
X: np.ndarray, # noqa: N803
111+
**kwargs,
111112
) -> DistanceResult:
112113
pass
113114

@@ -166,6 +167,7 @@ def _distance_measure(
166167
self,
167168
X_ref: np.ndarray, # noqa: N803
168169
X: np.ndarray, # noqa: N803
170+
**kwargs,
169171
) -> DistanceResult:
170172
distance_bins = self._distance_measure_bins(X_ref=X_ref, X=X)
171173
distance = DistanceResult(distance=distance_bins)
@@ -186,7 +188,9 @@ def _calculate_bins_values(
186188

187189
@abc.abstractmethod
188190
def _distance_measure_bins(
189-
self, X_ref: np.ndarray, X: np.ndarray # noqa: N803
191+
self,
192+
X_ref: np.ndarray, # noqa: N803
193+
X: np.ndarray, # noqa: N803
190194
) -> float:
191195
pass
192196

@@ -246,6 +250,7 @@ def _distance_measure(
246250
self,
247251
X_ref: np.ndarray, # noqa: N803
248252
X: np.ndarray, # noqa: N803
253+
**kwargs,
249254
) -> DistanceResult:
250255
pass
251256

frouros/detectors/data_drift/batch/distance_based/emd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def _distance_measure(
4545
self,
4646
X_ref: np.ndarray, # noqa: N803
4747
X: np.ndarray, # noqa: N803
48+
**kwargs,
4849
) -> DistanceResult:
4950
emd = self._emd(X=X_ref, Y=X, **self.kwargs)
5051
distance = DistanceResult(distance=emd)

frouros/detectors/data_drift/batch/distance_based/js.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626
self,
2727
num_bins: int = 10,
2828
callbacks: Optional[Union[Callback, List[Callback]]] = None,
29-
**kwargs
29+
**kwargs,
3030
) -> None:
3131
"""Init method.
3232
@@ -50,6 +50,7 @@ def _distance_measure(
5050
self,
5151
X_ref: np.ndarray, # noqa: N803
5252
X: np.ndarray, # noqa: N803
53+
**kwargs,
5354
) -> DistanceResult:
5455
js = self._js(X=X_ref, Y=X, num_bins=self.num_bins, **self.kwargs)
5556
distance = DistanceResult(distance=js)
@@ -61,7 +62,7 @@ def _js(
6162
Y: np.ndarray,
6263
*,
6364
num_bins: int,
64-
**kwargs: Dict[str, Any]
65+
**kwargs: Dict[str, Any],
6566
) -> float:
6667
( # noqa: N806
6768
X_ref_rvs,

frouros/detectors/data_drift/batch/distance_based/kl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def _distance_measure(
4747
self,
4848
X_ref: np.ndarray, # noqa: N803
4949
X: np.ndarray, # noqa: N803
50+
**kwargs,
5051
) -> DistanceResult:
5152
kl = self._kl(X=X_ref, Y=X, num_bins=self.num_bins, **self.kwargs)
5253
distance = DistanceResult(distance=kl)

frouros/detectors/data_drift/batch/distance_based/mmd.py

Lines changed: 160 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""MMD (Maximum Mean Discrepancy) module."""
22

3-
from typing import Callable, Optional, List, Union
3+
import itertools
4+
import math
5+
from typing import Callable, Iterator, Optional, List, Union
46

57
import numpy as np # type: ignore
68
from scipy.spatial.distance import cdist # type: ignore
9+
import tqdm # type: ignore
710

811
from frouros.callbacks import Callback
912
from frouros.detectors.data_drift.base import MultivariateData
@@ -43,12 +46,15 @@ class MMD(DistanceBasedBase):
4346
def __init__(
4447
self,
4548
kernel: Callable = rbf_kernel,
49+
chunk_size: Optional[int] = None,
4650
callbacks: Optional[Union[Callback, List[Callback]]] = None,
4751
) -> None:
4852
"""Init method.
4953
5054
:param kernel: kernel function to use
5155
:type kernel: Callable
56+
:param chunk_size:
57+
:type chunk_size: Optional[int]
5258
:param callbacks: callbacks
5359
:type callbacks: Optional[Union[Callback, List[Callback]]]
5460
"""
@@ -61,13 +67,42 @@ def __init__(
6167
callbacks=callbacks,
6268
)
6369
self.kernel = kernel
70+
self.chunk_size = chunk_size
71+
self._chunk_size_x = None
72+
self._expected_k_x = None
73+
self._X_num_samples = None
74+
75+
@property
76+
def chunk_size(self) -> Optional[int]:
77+
"""Chunk size property.
78+
79+
:return: chunk size to use
80+
:rtype: int
81+
"""
82+
return self._chunk_size
83+
84+
@chunk_size.setter
85+
def chunk_size(self, value: Optional[int]) -> None:
86+
"""Chunk size method setter.
87+
88+
:param value: value to be set
89+
:type value: Optional[int]
90+
:raises TypeError: Type error exception
91+
"""
92+
if value is not None:
93+
if isinstance(value, int): # type: ignore
94+
if value <= 0:
95+
raise ValueError("chunk_size must be greater than 0 or None.")
96+
else:
97+
raise TypeError("chunk_size must be of type int or None.")
98+
self._chunk_size = value
6499

65100
@property
66101
def kernel(self) -> Callable:
67102
"""Kernel property.
68103
69104
:return: kernel function to use
70-
:rtype: Kernel
105+
:rtype: Callable
71106
"""
72107
return self._kernel
73108

@@ -80,38 +115,147 @@ def kernel(self, value: Callable) -> None:
80115
:raises TypeError: Type error exception
81116
"""
82117
if not isinstance(value, Callable): # type: ignore
83-
raise TypeError("value must be of type Callable.")
118+
raise TypeError("kernel must be of type Callable.")
84119
self._kernel = value
85120

86121
def _distance_measure(
87122
self,
88123
X_ref: np.ndarray, # noqa: N803
89124
X: np.ndarray, # noqa: N803
125+
**kwargs,
90126
) -> DistanceResult:
91-
mmd = self._mmd(X=X_ref, Y=X, kernel=self.kernel)
127+
mmd = self._mmd(X=X_ref, Y=X, kernel=self.kernel, **kwargs)
92128
distance_test = DistanceResult(distance=mmd)
93129
return distance_test
94130

131+
def _fit(
132+
self,
133+
X: np.ndarray, # noqa: N803
134+
**kwargs,
135+
) -> None:
136+
super()._fit(X=X)
137+
# Add dimension only for the kernel calculation (if dim == 1)
138+
if X.ndim == 1:
139+
X = np.expand_dims(X, axis=1) # noqa: N806
140+
self._X_num_samples = len(X) # type: ignore # noqa: N806
141+
142+
self._chunk_size_x = (
143+
self._X_num_samples
144+
if self.chunk_size is None
145+
else self.chunk_size # type: ignore
146+
)
147+
148+
X_chunks = self._get_chunks( # noqa: N806
149+
data=X,
150+
chunk_size=self._chunk_size_x, # type: ignore
151+
)
152+
X_chunks_combinations = itertools.product(X_chunks, repeat=2) # noqa: N806
153+
154+
if kwargs.get("verbose", False):
155+
num_chunks = (
156+
math.ceil(self._X_num_samples / self._chunk_size_x) ** 2 # type: ignore
157+
)
158+
k_x_sum = np.array(
159+
[
160+
self.kernel(*X_chunk).sum()
161+
for X_chunk in tqdm.tqdm( # noqa: N806
162+
X_chunks_combinations, total=num_chunks # noqa: N806
163+
)
164+
]
165+
).sum()
166+
else:
167+
k_x_sum = np.array(
168+
[
169+
self.kernel(*X_chunk).sum()
170+
for X_chunk in X_chunks_combinations # noqa: N806
171+
]
172+
).sum()
173+
self._expected_k_x = k_x_sum / (
174+
self._X_num_samples * (self._X_num_samples - 1) # type: ignore
175+
)
176+
95177
@staticmethod
96-
def _mmd(
178+
def _get_chunks(data: np.ndarray, chunk_size: int) -> Iterator:
179+
chunks = (
180+
data[i : i + chunk_size] # noqa: E203
181+
for i in range(0, len(data), chunk_size)
182+
)
183+
return chunks
184+
185+
def _mmd( # pylint: disable=too-many-locals
186+
self,
97187
X: np.ndarray, # noqa: N803
98188
Y: np.ndarray,
99189
*,
100190
kernel: Callable,
191+
**kwargs,
101192
) -> float: # noqa: N803
102-
X_num_samples = X.shape[0] # noqa: N806
103-
Y_num_samples = Y.shape[0] # noqa: N806
104-
data = np.concatenate([X, Y]) # noqa: N806
193+
# Only check for X dimension (X == Y dim comparison has been already made)
105194
if X.ndim == 1:
106-
data = np.expand_dims(data, axis=1)
195+
X = np.expand_dims(X, axis=1) # noqa: N806
196+
Y = np.expand_dims(Y, axis=1) # noqa: N806
197+
198+
X_chunks = self._get_chunks( # noqa: N806
199+
data=X,
200+
chunk_size=self._chunk_size_x, # type: ignore
201+
)
202+
Y_num_samples = len(Y) # noqa: N806
203+
chunk_size_y = Y_num_samples if self.chunk_size is None else self.chunk_size
204+
Y_chunks, Y_chunks_copy = itertools.tee( # noqa: N806
205+
self._get_chunks(
206+
data=Y,
207+
chunk_size=chunk_size_y, # type: ignore
208+
),
209+
2,
210+
)
211+
Y_chunks_combinations = itertools.product( # noqa: N806
212+
Y_chunks,
213+
repeat=2,
214+
)
215+
XY_chunks_combinations = itertools.product( # noqa: N806
216+
X_chunks,
217+
Y_chunks_copy,
218+
)
219+
220+
if kwargs.get("verbose", False):
221+
num_chunks_y = math.ceil(Y_num_samples / self.chunk_size) # type: ignore
222+
num_chunks_y_combinations = num_chunks_y**2
223+
num_chunks_xy = (
224+
math.ceil(len(X) / self._chunk_size_x) * num_chunks_y # type: ignore
225+
)
226+
sum_y = np.array(
227+
[
228+
kernel(*Y_chunk).sum()
229+
for Y_chunk in tqdm.tqdm( # noqa: N806
230+
Y_chunks_combinations, total=num_chunks_y_combinations
231+
)
232+
]
233+
).sum()
234+
sum_xy = np.array(
235+
[
236+
kernel(*XY_chunk).sum()
237+
for XY_chunk in tqdm.tqdm( # noqa: N806
238+
XY_chunks_combinations, total=num_chunks_xy
239+
)
240+
]
241+
).sum()
242+
else:
243+
sum_y = np.array(
244+
[
245+
kernel(*Y_chunk).sum()
246+
for Y_chunk in Y_chunks_combinations # noqa: N806
247+
]
248+
).sum()
249+
sum_xy = np.array(
250+
[
251+
kernel(*XY_chunk).sum()
252+
for XY_chunk in XY_chunks_combinations # noqa: N806
253+
]
254+
).sum()
107255

108-
k_matrix = kernel(X=data, Y=data)
109-
k_x = k_matrix[:X_num_samples, :X_num_samples]
110-
k_y = k_matrix[Y_num_samples:, Y_num_samples:]
111-
k_xy = k_matrix[:X_num_samples, Y_num_samples:]
112256
mmd = (
113-
k_x.sum() / (X_num_samples * (X_num_samples - 1))
114-
+ k_y.sum() / (Y_num_samples * (Y_num_samples - 1))
115-
- 2 * k_xy.sum() / (X_num_samples * Y_num_samples)
257+
self._expected_k_x
258+
+ sum_y / (Y_num_samples * (Y_num_samples - 1))
259+
- 2 * sum_xy / (self._X_num_samples * Y_num_samples) # type: ignore
116260
)
117261
return mmd

frouros/tests/integration/test_data_drift.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Test data drift detectors."""
22

3-
from typing import Tuple
3+
from typing import Tuple, Union
44

55
import pytest # type: ignore
66
import numpy as np # type: ignore
@@ -243,6 +243,56 @@ def test_batch_distance_based_multivariate_same_distribution(
243243
assert np.isclose(statistic, expected_distance)
244244

245245

246+
@pytest.mark.parametrize(
247+
"detector, expected_distance",
248+
[(MMD(chunk_size=10), 0.12183835), (MMD(chunk_size=None), 0.12183835)],
249+
)
250+
def test_batch_distance_based_chunk_size_valid(
251+
X_ref_multivariate: np.ndarray, # noqa: N803
252+
X_test_multivariate: np.ndarray, # noqa: N803
253+
detector: DataDriftBatchBase,
254+
expected_distance: float,
255+
) -> None:
256+
"""Test batch distance based chunk size valid method.
257+
258+
:param X_ref_multivariate: reference multivariate data
259+
:type X_ref_multivariate: numpy.ndarray
260+
:param X_test_multivariate: test multivariate data
261+
:type X_test_multivariate: numpy.ndarray
262+
:param detector: detector test
263+
:type detector: DataDriftBatchBase
264+
:param expected_distance: expected distance value
265+
:type expected_distance: float
266+
"""
267+
_ = detector.fit(X=X_ref_multivariate)
268+
statistic, _ = detector.compare(X=X_test_multivariate)
269+
270+
assert np.isclose(statistic, expected_distance)
271+
272+
273+
@pytest.mark.parametrize(
274+
"chunk_size, expected_exception",
275+
[
276+
(1.5, TypeError),
277+
("10", TypeError),
278+
(-1, ValueError),
279+
],
280+
)
281+
def test_batch_distance_based_chunk_size_invalid(
282+
chunk_size: Union[int, float, str],
283+
expected_exception: Union[TypeError, ValueError],
284+
) -> None:
285+
"""Test batch distance based chunk size invalid method.
286+
287+
:param chunk_size: chunk size
288+
:type chunk_size: Union[int, float, str]
289+
:param expected_exception: expected exception
290+
:type expected_exception: Union[TypeError, ValueError]
291+
"""
292+
with pytest.raises(expected_exception):
293+
_ = MMD(chunk_size=chunk_size) # type: ignore
294+
295+
246296
@pytest.mark.parametrize(
247297
"detector, expected_statistic, expected_p_value",
248298
[

0 commit comments

Comments
 (0)