Skip to content

Commit 6dc1536

Browse files
Fix _mmd from MMD to be a static method
1 parent ea25e65 commit 6dc1536

File tree

1 file changed

+50
-62
lines changed
  • frouros/detectors/data_drift/batch/distance_based

1 file changed

+50
-62
lines changed

frouros/detectors/data_drift/batch/distance_based/mmd.py

Lines changed: 50 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ def __init__(
5151
)
5252
self.kernel = kernel
5353
self.chunk_size = chunk_size
54-
self._chunk_size_x = None
55-
self.X_chunks_combinations = None
56-
self.X_num_samples = None
57-
self.expected_k_xx = None
5854

5955
@property
6056
def chunk_size(self) -> Optional[int]:
@@ -108,55 +104,16 @@ def _distance_measure(
108104
X: np.ndarray, # noqa: N803
109105
**kwargs,
110106
) -> DistanceResult:
111-
mmd = self._mmd(X=X_ref, Y=X, kernel=self.kernel, **kwargs)
107+
mmd = self._mmd(
108+
X=X_ref,
109+
Y=X,
110+
kernel=self.kernel,
111+
chunk_size=self.chunk_size,
112+
**kwargs,
113+
)
112114
distance_test = DistanceResult(distance=mmd)
113115
return distance_test
114116

115-
def _fit(
116-
self,
117-
X: np.ndarray, # noqa: N803
118-
**kwargs,
119-
) -> None:
120-
super()._fit(X=X)
121-
# Add dimension only for the kernel calculation (if dim == 1)
122-
if X.ndim == 1:
123-
X = np.expand_dims(X, axis=1) # noqa: N806
124-
self.X_num_samples = len(self.X_ref) # type: ignore # noqa: N806
125-
126-
self._chunk_size_x = (
127-
self.X_num_samples
128-
if self.chunk_size is None
129-
else self.chunk_size # type: ignore
130-
)
131-
132-
X_chunks = self._get_chunks( # noqa: N806
133-
data=X,
134-
chunk_size=self._chunk_size_x, # type: ignore
135-
)
136-
xx_chunks_combinations = itertools.product(X_chunks, repeat=2) # noqa: N806
137-
138-
if kwargs.get("verbose", False):
139-
num_chunks = (
140-
math.ceil(self.X_num_samples / self._chunk_size_x) ** 2 # type: ignore
141-
)
142-
xx_chunks_combinations = tqdm.tqdm(
143-
xx_chunks_combinations,
144-
total=num_chunks,
145-
)
146-
147-
k_xx_sum = (
148-
self._compute_kernel(
149-
chunk_combinations=xx_chunks_combinations, # type: ignore
150-
kernel=self.kernel,
151-
)
152-
# Remove diagonal (j!=i case)
153-
- self.X_num_samples # type: ignore
154-
)
155-
156-
self.expected_k_xx = k_xx_sum / ( # type: ignore
157-
self.X_num_samples * (self.X_num_samples - 1) # type: ignore
158-
)
159-
160117
@staticmethod
161118
def _compute_kernel(chunk_combinations: Generator, kernel: Callable) -> float:
162119
k_sum = np.array([kernel(*chunk).sum() for chunk in chunk_combinations]).sum()
@@ -170,8 +127,8 @@ def _get_chunks(data: np.ndarray, chunk_size: int) -> Generator:
170127
)
171128
return chunks
172129

130+
@staticmethod
173131
def _mmd( # pylint: disable=too-many-locals
174-
self,
175132
X: np.ndarray, # noqa: N803
176133
Y: np.ndarray,
177134
*,
@@ -183,33 +140,56 @@ def _mmd( # pylint: disable=too-many-locals
183140
X = np.expand_dims(X, axis=1) # noqa: N806
184141
Y = np.expand_dims(Y, axis=1) # noqa: N806
185142

186-
X_chunks = self._get_chunks( # noqa: N806
187-
data=X,
188-
chunk_size=self._chunk_size_x, # type: ignore
143+
x_num_samples = len(X) # noqa: N806
144+
chunk_size_x = (
145+
kwargs["chunk_size"]
146+
if "chunk_size" in kwargs and kwargs["chunk_size"] is not None
147+
else x_num_samples
148+
)
149+
x_chunks, x_chunks_copy = itertools.tee( # noqa: N806
150+
MMD._get_chunks(
151+
data=X,
152+
chunk_size=chunk_size_x, # type: ignore
153+
),
154+
2,
189155
)
190156
y_num_samples = len(Y) # noqa: N806
191-
chunk_size_y = y_num_samples if self.chunk_size is None else self.chunk_size
157+
chunk_size_y = (
158+
kwargs["chunk_size"]
159+
if "chunk_size" in kwargs and kwargs["chunk_size"] is not None
160+
else y_num_samples
161+
)
192162
y_chunks, y_chunks_copy = itertools.tee( # noqa: N806
193-
self._get_chunks(
163+
MMD._get_chunks(
194164
data=Y,
195165
chunk_size=chunk_size_y, # type: ignore
196166
),
197167
2,
198168
)
169+
x_chunks_combinations = itertools.product( # noqa: N806
170+
x_chunks,
171+
repeat=2,
172+
)
199173
y_chunks_combinations = itertools.product( # noqa: N806
200174
y_chunks,
201175
repeat=2,
202176
)
203177
xy_chunks_combinations = itertools.product( # noqa: N806
204-
X_chunks,
178+
x_chunks_copy,
205179
y_chunks_copy,
206180
)
207181

208182
if kwargs.get("verbose", False):
183+
num_chunks_x = math.ceil(x_num_samples / chunk_size_x) # type: ignore
209184
num_chunks_y = math.ceil(y_num_samples / chunk_size_y) # type: ignore
185+
num_chunks_x_combinations = num_chunks_x**2
210186
num_chunks_y_combinations = num_chunks_y**2
211187
num_chunks_xy = (
212-
math.ceil(len(X) / self._chunk_size_x) * num_chunks_y # type: ignore
188+
math.ceil(len(X) / chunk_size_x) * num_chunks_y # type: ignore
189+
)
190+
x_chunks_combinations = tqdm.tqdm(
191+
x_chunks_combinations,
192+
total=num_chunks_x_combinations,
213193
)
214194
y_chunks_combinations = tqdm.tqdm(
215195
y_chunks_combinations,
@@ -220,21 +200,29 @@ def _mmd( # pylint: disable=too-many-locals
220200
total=num_chunks_xy,
221201
)
222202

203+
k_xx_sum = (
204+
MMD._compute_kernel(
205+
chunk_combinations=x_chunks_combinations, # type: ignore
206+
kernel=kernel,
207+
)
208+
# Remove diagonal (j!=i case)
209+
- x_num_samples # type: ignore
210+
)
223211
k_yy_sum = (
224-
self._compute_kernel(
212+
MMD._compute_kernel(
225213
chunk_combinations=y_chunks_combinations, # type: ignore
226214
kernel=kernel,
227215
)
228216
# Remove diagonal (j!=i case)
229217
- y_num_samples # type: ignore
230218
)
231-
k_xy_sum = self._compute_kernel(
219+
k_xy_sum = MMD._compute_kernel(
232220
chunk_combinations=xy_chunks_combinations, # type: ignore
233221
kernel=kernel,
234222
)
235223
mmd = (
236-
self.expected_k_xx # type: ignore
224+
+k_xx_sum / (x_num_samples * (x_num_samples - 1))
237225
+ k_yy_sum / (y_num_samples * (y_num_samples - 1))
238-
- 2 * k_xy_sum / (self.X_num_samples * y_num_samples) # type: ignore
226+
- 2 * k_xy_sum / (x_num_samples * y_num_samples) # type: ignore
239227
)
240228
return mmd

0 commit comments

Comments
 (0)