Skip to content

Commit 3598aa2

Browse files
Fix PEP8
1 parent f137843 commit 3598aa2

File tree

2 files changed

+20
-25
lines changed
  • frouros
    • detectors/data_drift/batch/distance_based
    • tests/unit/detectors/data_drift/batch/distance_based

2 files changed

+20
-25
lines changed

frouros/detectors/data_drift/batch/distance_based/mmd.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
"""MMD (Maximum Mean Discrepancy) module."""
22

33
import itertools
4-
import math
54
from typing import Callable, Generator, Optional, List, Union
65

76
import numpy as np # type: ignore
8-
import tqdm # type: ignore
97

108
from frouros.callbacks.batch.base import BaseCallbackBatch
119
from frouros.detectors.data_drift.base import MultivariateData
@@ -137,7 +135,7 @@ def _fit(
137135
# Add dimension only for the kernel calculation (if dim == 1)
138136
if X.ndim == 1:
139137
X = np.expand_dims(X, axis=1) # noqa: N806
140-
x_num_samples = len(self.X_ref) # type: ignore # noqa: N806
138+
x_num_samples = len(self.X_ref) # type: ignore
141139

142140
chunk_size_x = (
143141
x_num_samples
@@ -147,7 +145,7 @@ def _fit(
147145

148146
x_chunks = self._get_chunks( # noqa: N806
149147
data=X,
150-
chunk_size=chunk_size_x, # type: ignore
148+
chunk_size=chunk_size_x,
151149
)
152150
x_chunks_combinations = itertools.product(x_chunks, repeat=2) # noqa: N806
153151

@@ -157,11 +155,11 @@ def _fit(
157155
kernel=self.kernel,
158156
)
159157
# Remove diagonal (j!=i case)
160-
- x_num_samples # type: ignore
158+
- x_num_samples
161159
)
162160

163161
self._expected_k_xx = k_xx_sum / ( # type: ignore
164-
x_num_samples * (x_num_samples - 1) # type: ignore
162+
x_num_samples * (x_num_samples - 1)
165163
)
166164

167165
@staticmethod
@@ -201,33 +199,31 @@ def _mmd( # pylint: disable=too-many-locals
201199
if "expected_k_xx" in kwargs:
202200
x_chunks_copy = MMD._get_chunks( # noqa: N806
203201
data=X,
204-
chunk_size=chunk_size_x, # type: ignore
202+
chunk_size=chunk_size_x,
205203
)
206204
expected_k_xx = kwargs["expected_k_xx"]
207205
else:
208206
# Compute expected_k_xx
209-
x_chunks, x_chunks_copy = itertools.tee( # noqa: N806
207+
x_chunks, x_chunks_copy = itertools.tee( # type: ignore
210208
MMD._get_chunks(
211209
data=X,
212-
chunk_size=chunk_size_x, # type: ignore
210+
chunk_size=chunk_size_x,
213211
),
214212
2,
215213
)
216-
x_chunks_combinations = itertools.product( # noqa: N806
214+
x_chunks_combinations = itertools.product( # type: ignore
217215
x_chunks,
218216
repeat=2,
219217
)
220218
k_xx_sum = (
221-
MMD._compute_kernel(
222-
chunk_combinations=x_chunks_combinations, # type: ignore
223-
kernel=kernel,
224-
)
225-
# Remove diagonal (j!=i case)
226-
- x_num_samples # type: ignore
227-
)
228-
expected_k_xx = k_xx_sum / ( # type: ignore
229-
x_num_samples * (x_num_samples - 1) # type: ignore
219+
MMD._compute_kernel(
220+
chunk_combinations=x_chunks_combinations, # type: ignore
221+
kernel=kernel,
222+
)
223+
# Remove diagonal (j!=i case)
224+
- x_num_samples
230225
)
226+
expected_k_xx = k_xx_sum / (x_num_samples * (x_num_samples - 1))
231227

232228
y_num_samples = len(Y) # noqa: N806
233229
chunk_size_y = (
@@ -238,7 +234,7 @@ def _mmd( # pylint: disable=too-many-locals
238234
y_chunks, y_chunks_copy = itertools.tee( # noqa: N806
239235
MMD._get_chunks(
240236
data=Y,
241-
chunk_size=chunk_size_y, # type: ignore
237+
chunk_size=chunk_size_y,
242238
),
243239
2,
244240
)
@@ -257,15 +253,15 @@ def _mmd( # pylint: disable=too-many-locals
257253
kernel=kernel,
258254
)
259255
# Remove diagonal (j!=i case)
260-
- y_num_samples # type: ignore
256+
- y_num_samples
261257
)
262258
k_xy_sum = MMD._compute_kernel(
263259
chunk_combinations=xy_chunks_combinations, # type: ignore
264260
kernel=kernel,
265261
)
266262
mmd = (
267-
+ expected_k_xx
263+
+expected_k_xx
268264
+ k_yy_sum / (y_num_samples * (y_num_samples - 1))
269-
- 2 * k_xy_sum / (x_num_samples * y_num_samples) # type: ignore
265+
- 2 * k_xy_sum / (x_num_samples * y_num_samples)
270266
)
271267
return mmd

frouros/tests/unit/detectors/data_drift/batch/distance_based/test_mmd.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,11 @@ def test_mmd_batch_precomputed_expected_k_xx(
9393
precomputed_distance = detector.compare(X=X_test)[0].distance
9494

9595
# Computes mmd from scratch
96-
scratch_distance = MMD._mmd(
96+
scratch_distance = MMD._mmd( # pylint: disable=protected-access
9797
X=X_ref,
9898
Y=X_test,
9999
kernel=kernel,
100100
chunk_size=chunk_size,
101101
)
102102

103103
assert np.isclose(precomputed_distance, scratch_distance)
104-

0 commit comments

Comments
 (0)