Skip to content

Commit d2ef3e6

Browse files
Precompute kernel matrix of reference data in MMD
1 parent 03b5fe7 commit d2ef3e6

File tree

1 file changed

+69
-39
lines changed
  • frouros/detectors/data_drift/batch/distance_based

1 file changed

+69
-39
lines changed

frouros/detectors/data_drift/batch/distance_based/mmd.py

Lines changed: 69 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__( # noqa: D107
6464
)
6565
self.kernel = kernel
6666
self.chunk_size = chunk_size
67+
self._expected_k_xx = None
6768

6869
@property
6970
def chunk_size(self) -> Optional[int]:
@@ -122,11 +123,47 @@ def _distance_measure(
122123
Y=X,
123124
kernel=self.kernel,
124125
chunk_size=self.chunk_size,
126+
expected_k_xx=self._expected_k_xx,
125127
**kwargs,
126128
)
127129
distance_test = DistanceResult(distance=mmd)
128130
return distance_test
129131

132+
def _fit(
    self,
    X: np.ndarray,  # noqa: N803
) -> None:
    """Fit the detector and cache the reference kernel expectation.

    Runs the parent ``_fit`` and then stores in ``self._expected_k_xx``
    the kernel sum over all pairs of reference samples, minus the number
    of samples, divided by ``n * (n - 1)``. ``_distance_measure`` forwards
    this cached value so it is not recomputed on every comparison.

    :param X: reference data
    :type X: numpy.ndarray
    """
    super()._fit(X=X)

    # The kernel works on 2D input; the extra axis is added only to the
    # local copy used for this computation.
    X = X if X.ndim != 1 else np.expand_dims(X, axis=1)  # noqa: N806
    num_ref_samples = len(self.X_ref)  # type: ignore

    # Without an explicit chunk size the whole reference set is one chunk.
    rows_per_chunk = (
        self.chunk_size  # type: ignore
        if self.chunk_size is not None
        else num_ref_samples
    )

    ref_chunks = self._get_chunks(
        data=X,
        chunk_size=rows_per_chunk,  # type: ignore
    )
    # Every ordered pair of chunks, including each chunk with itself.
    ref_chunk_pairs = itertools.product(ref_chunks, repeat=2)

    kernel_total = self._compute_kernel(
        chunk_combinations=ref_chunk_pairs,  # type: ignore
        kernel=self.kernel,
    )
    # Remove diagonal (j!=i case).
    # NOTE(review): subtracting the sample count assumes k(x, x) == 1 for
    # every sample -- confirm for kernels other than the default.
    off_diagonal_total = kernel_total - num_ref_samples  # type: ignore

    # NOTE(review): the denominator is zero for a single reference sample.
    pair_count = num_ref_samples * (num_ref_samples - 1)  # type: ignore
    self._expected_k_xx = off_diagonal_total / pair_count  # type: ignore
166+
130167
@staticmethod
131168
def _compute_kernel(chunk_combinations: Generator, kernel: Callable) -> float:
132169
k_sum = np.array([kernel(*chunk).sum() for chunk in chunk_combinations]).sum()
@@ -159,13 +196,39 @@ def _mmd( # pylint: disable=too-many-locals
159196
if "chunk_size" in kwargs and kwargs["chunk_size"] is not None
160197
else x_num_samples
161198
)
162-
x_chunks, x_chunks_copy = itertools.tee( # noqa: N806
163-
MMD._get_chunks(
199+
200+
# If expected_k_xx is provided, we don't need to compute it again
201+
if "expected_k_xx" in kwargs:
202+
x_chunks_copy = MMD._get_chunks( # noqa: N806
164203
data=X,
165204
chunk_size=chunk_size_x, # type: ignore
166-
),
167-
2,
168-
)
205+
)
206+
expected_k_xx = kwargs["expected_k_xx"]
207+
else:
208+
# Compute expected_k_xx
209+
x_chunks, x_chunks_copy = itertools.tee( # noqa: N806
210+
MMD._get_chunks(
211+
data=X,
212+
chunk_size=chunk_size_x, # type: ignore
213+
),
214+
2,
215+
)
216+
x_chunks_combinations = itertools.product( # noqa: N806
217+
x_chunks,
218+
repeat=2,
219+
)
220+
k_xx_sum = (
221+
MMD._compute_kernel(
222+
chunk_combinations=x_chunks_combinations, # type: ignore
223+
kernel=kernel,
224+
)
225+
# Remove diagonal (j!=i case)
226+
- x_num_samples # type: ignore
227+
)
228+
expected_k_xx = k_xx_sum / ( # type: ignore
229+
x_num_samples * (x_num_samples - 1) # type: ignore
230+
)
231+
169232
y_num_samples = len(Y) # noqa: N806
170233
chunk_size_y = (
171234
kwargs["chunk_size"]
@@ -179,10 +242,6 @@ def _mmd( # pylint: disable=too-many-locals
179242
),
180243
2,
181244
)
182-
x_chunks_combinations = itertools.product( # noqa: N806
183-
x_chunks,
184-
repeat=2,
185-
)
186245
y_chunks_combinations = itertools.product( # noqa: N806
187246
y_chunks,
188247
repeat=2,
@@ -192,35 +251,6 @@ def _mmd( # pylint: disable=too-many-locals
192251
y_chunks_copy,
193252
)
194253

195-
if kwargs.get("verbose", False):
196-
num_chunks_x = math.ceil(x_num_samples / chunk_size_x) # type: ignore
197-
num_chunks_y = math.ceil(y_num_samples / chunk_size_y) # type: ignore
198-
num_chunks_x_combinations = num_chunks_x**2
199-
num_chunks_y_combinations = num_chunks_y**2
200-
num_chunks_xy = (
201-
math.ceil(len(X) / chunk_size_x) * num_chunks_y # type: ignore
202-
)
203-
x_chunks_combinations = tqdm.tqdm(
204-
x_chunks_combinations,
205-
total=num_chunks_x_combinations,
206-
)
207-
y_chunks_combinations = tqdm.tqdm(
208-
y_chunks_combinations,
209-
total=num_chunks_y_combinations,
210-
)
211-
xy_chunks_combinations = tqdm.tqdm(
212-
xy_chunks_combinations,
213-
total=num_chunks_xy,
214-
)
215-
216-
k_xx_sum = (
217-
MMD._compute_kernel(
218-
chunk_combinations=x_chunks_combinations, # type: ignore
219-
kernel=kernel,
220-
)
221-
# Remove diagonal (j!=i case)
222-
- x_num_samples # type: ignore
223-
)
224254
k_yy_sum = (
225255
MMD._compute_kernel(
226256
chunk_combinations=y_chunks_combinations, # type: ignore
@@ -234,7 +264,7 @@ def _mmd( # pylint: disable=too-many-locals
234264
kernel=kernel,
235265
)
236266
mmd = (
237-
+k_xx_sum / (x_num_samples * (x_num_samples - 1))
267+
+ expected_k_xx
238268
+ k_yy_sum / (y_num_samples * (y_num_samples - 1))
239269
- 2 * k_xy_sum / (x_num_samples * y_num_samples) # type: ignore
240270
)

0 commit comments

Comments
 (0)