11"""MMD (Maximum Mean Discrepancy) module."""
22
33import itertools
4- import math
54from typing import Callable , Generator , Optional , List , Union
65
76import numpy as np # type: ignore
8- import tqdm # type: ignore
97
108from frouros .callbacks .batch .base import BaseCallbackBatch
119from frouros .detectors .data_drift .base import MultivariateData
@@ -137,7 +135,7 @@ def _fit(
137135 # Add dimension only for the kernel calculation (if dim == 1)
138136 if X .ndim == 1 :
139137 X = np .expand_dims (X , axis = 1 ) # noqa: N806
140- x_num_samples = len (self .X_ref ) # type: ignore # noqa: N806
138+ x_num_samples = len (self .X_ref ) # type: ignore
141139
142140 chunk_size_x = (
143141 x_num_samples
@@ -147,7 +145,7 @@ def _fit(
147145
148146 x_chunks = self ._get_chunks ( # noqa: N806
149147 data = X ,
150- chunk_size = chunk_size_x , # type: ignore
148+ chunk_size = chunk_size_x ,
151149 )
152150 x_chunks_combinations = itertools .product (x_chunks , repeat = 2 ) # noqa: N806
153151
@@ -157,11 +155,11 @@ def _fit(
157155 kernel = self .kernel ,
158156 )
159157 # Remove diagonal (j!=i case)
160- - x_num_samples # type: ignore
158+ - x_num_samples
161159 )
162160
163161 self ._expected_k_xx = k_xx_sum / ( # type: ignore
164- x_num_samples * (x_num_samples - 1 ) # type: ignore
162+ x_num_samples * (x_num_samples - 1 )
165163 )
166164
167165 @staticmethod
@@ -201,33 +199,31 @@ def _mmd( # pylint: disable=too-many-locals
201199 if "expected_k_xx" in kwargs :
202200 x_chunks_copy = MMD ._get_chunks ( # noqa: N806
203201 data = X ,
204- chunk_size = chunk_size_x , # type: ignore
202+ chunk_size = chunk_size_x ,
205203 )
206204 expected_k_xx = kwargs ["expected_k_xx" ]
207205 else :
208206 # Compute expected_k_xx
209- x_chunks , x_chunks_copy = itertools .tee ( # noqa: N806
207+ x_chunks , x_chunks_copy = itertools .tee ( # type: ignore
210208 MMD ._get_chunks (
211209 data = X ,
212- chunk_size = chunk_size_x , # type: ignore
210+ chunk_size = chunk_size_x ,
213211 ),
214212 2 ,
215213 )
216- x_chunks_combinations = itertools .product ( # noqa: N806
214+ x_chunks_combinations = itertools .product ( # type: ignore
217215 x_chunks ,
218216 repeat = 2 ,
219217 )
220218 k_xx_sum = (
221- MMD ._compute_kernel (
222- chunk_combinations = x_chunks_combinations , # type: ignore
223- kernel = kernel ,
224- )
225- # Remove diagonal (j!=i case)
226- - x_num_samples # type: ignore
227- )
228- expected_k_xx = k_xx_sum / ( # type: ignore
229- x_num_samples * (x_num_samples - 1 ) # type: ignore
219+ MMD ._compute_kernel (
220+ chunk_combinations = x_chunks_combinations , # type: ignore
221+ kernel = kernel ,
222+ )
223+ # Remove diagonal (j!=i case)
224+ - x_num_samples
230225 )
226+ expected_k_xx = k_xx_sum / (x_num_samples * (x_num_samples - 1 ))
231227
232228 y_num_samples = len (Y ) # noqa: N806
233229 chunk_size_y = (
@@ -238,7 +234,7 @@ def _mmd( # pylint: disable=too-many-locals
238234 y_chunks , y_chunks_copy = itertools .tee ( # noqa: N806
239235 MMD ._get_chunks (
240236 data = Y ,
241- chunk_size = chunk_size_y , # type: ignore
237+ chunk_size = chunk_size_y ,
242238 ),
243239 2 ,
244240 )
@@ -257,15 +253,15 @@ def _mmd( # pylint: disable=too-many-locals
257253 kernel = kernel ,
258254 )
259255 # Remove diagonal (j!=i case)
260- - y_num_samples # type: ignore
256+ - y_num_samples
261257 )
262258 k_xy_sum = MMD ._compute_kernel (
263259 chunk_combinations = xy_chunks_combinations , # type: ignore
264260 kernel = kernel ,
265261 )
266262 mmd = (
267- + expected_k_xx
263+ + expected_k_xx
268264 + k_yy_sum / (y_num_samples * (y_num_samples - 1 ))
269- - 2 * k_xy_sum / (x_num_samples * y_num_samples ) # type: ignore
265+ - 2 * k_xy_sum / (x_num_samples * y_num_samples )
270266 )
271267 return mmd
0 commit comments