@@ -64,6 +64,7 @@ def __init__( # noqa: D107
6464 )
6565 self .kernel = kernel
6666 self .chunk_size = chunk_size
67+ self ._expected_k_xx = None
6768
6869 @property
6970 def chunk_size (self ) -> Optional [int ]:
@@ -122,11 +123,47 @@ def _distance_measure(
122123 Y = X ,
123124 kernel = self .kernel ,
124125 chunk_size = self .chunk_size ,
126+ expected_k_xx = self ._expected_k_xx ,
125127 ** kwargs ,
126128 )
127129 distance_test = DistanceResult (distance = mmd )
128130 return distance_test
129131
132+ def _fit (
133+ self ,
134+ X : np .ndarray , # noqa: N803
135+ ) -> None :
136+ super ()._fit (X = X )
137+ # Add dimension only for the kernel calculation (if dim == 1)
138+ if X .ndim == 1 :
139+ X = np .expand_dims (X , axis = 1 ) # noqa: N806
140+ x_num_samples = len (self .X_ref ) # type: ignore # noqa: N806
141+
142+ chunk_size_x = (
143+ x_num_samples
144+ if self .chunk_size is None
145+ else self .chunk_size # type: ignore
146+ )
147+
148+ x_chunks = self ._get_chunks ( # noqa: N806
149+ data = X ,
150+ chunk_size = chunk_size_x , # type: ignore
151+ )
152+ x_chunks_combinations = itertools .product (x_chunks , repeat = 2 ) # noqa: N806
153+
154+ k_xx_sum = (
155+ self ._compute_kernel (
156+ chunk_combinations = x_chunks_combinations , # type: ignore
157+ kernel = self .kernel ,
158+ )
159+ # Remove diagonal (j!=i case)
160+ - x_num_samples # type: ignore
161+ )
162+
163+ self ._expected_k_xx = k_xx_sum / ( # type: ignore
164+ x_num_samples * (x_num_samples - 1 ) # type: ignore
165+ )
166+
130167 @staticmethod
131168 def _compute_kernel (chunk_combinations : Generator , kernel : Callable ) -> float :
132169 k_sum = np .array ([kernel (* chunk ).sum () for chunk in chunk_combinations ]).sum ()
@@ -159,13 +196,39 @@ def _mmd( # pylint: disable=too-many-locals
159196 if "chunk_size" in kwargs and kwargs ["chunk_size" ] is not None
160197 else x_num_samples
161198 )
162- x_chunks , x_chunks_copy = itertools .tee ( # noqa: N806
163- MMD ._get_chunks (
199+
200+ # If expected_k_xx is provided, we don't need to compute it again
201+ if "expected_k_xx" in kwargs :
202+ x_chunks_copy = MMD ._get_chunks ( # noqa: N806
164203 data = X ,
165204 chunk_size = chunk_size_x , # type: ignore
166- ),
167- 2 ,
168- )
205+ )
206+ expected_k_xx = kwargs ["expected_k_xx" ]
207+ else :
208+ # Compute expected_k_xx
209+ x_chunks , x_chunks_copy = itertools .tee ( # noqa: N806
210+ MMD ._get_chunks (
211+ data = X ,
212+ chunk_size = chunk_size_x , # type: ignore
213+ ),
214+ 2 ,
215+ )
216+ x_chunks_combinations = itertools .product ( # noqa: N806
217+ x_chunks ,
218+ repeat = 2 ,
219+ )
220+ k_xx_sum = (
221+ MMD ._compute_kernel (
222+ chunk_combinations = x_chunks_combinations , # type: ignore
223+ kernel = kernel ,
224+ )
225+ # Remove diagonal (j!=i case)
226+ - x_num_samples # type: ignore
227+ )
228+ expected_k_xx = k_xx_sum / ( # type: ignore
229+ x_num_samples * (x_num_samples - 1 ) # type: ignore
230+ )
231+
169232 y_num_samples = len (Y ) # noqa: N806
170233 chunk_size_y = (
171234 kwargs ["chunk_size" ]
@@ -179,10 +242,6 @@ def _mmd( # pylint: disable=too-many-locals
179242 ),
180243 2 ,
181244 )
182- x_chunks_combinations = itertools .product ( # noqa: N806
183- x_chunks ,
184- repeat = 2 ,
185- )
186245 y_chunks_combinations = itertools .product ( # noqa: N806
187246 y_chunks ,
188247 repeat = 2 ,
@@ -192,35 +251,6 @@ def _mmd( # pylint: disable=too-many-locals
192251 y_chunks_copy ,
193252 )
194253
195- if kwargs .get ("verbose" , False ):
196- num_chunks_x = math .ceil (x_num_samples / chunk_size_x ) # type: ignore
197- num_chunks_y = math .ceil (y_num_samples / chunk_size_y ) # type: ignore
198- num_chunks_x_combinations = num_chunks_x ** 2
199- num_chunks_y_combinations = num_chunks_y ** 2
200- num_chunks_xy = (
201- math .ceil (len (X ) / chunk_size_x ) * num_chunks_y # type: ignore
202- )
203- x_chunks_combinations = tqdm .tqdm (
204- x_chunks_combinations ,
205- total = num_chunks_x_combinations ,
206- )
207- y_chunks_combinations = tqdm .tqdm (
208- y_chunks_combinations ,
209- total = num_chunks_y_combinations ,
210- )
211- xy_chunks_combinations = tqdm .tqdm (
212- xy_chunks_combinations ,
213- total = num_chunks_xy ,
214- )
215-
216- k_xx_sum = (
217- MMD ._compute_kernel (
218- chunk_combinations = x_chunks_combinations , # type: ignore
219- kernel = kernel ,
220- )
221- # Remove diagonal (j!=i case)
222- - x_num_samples # type: ignore
223- )
224254 k_yy_sum = (
225255 MMD ._compute_kernel (
226256 chunk_combinations = y_chunks_combinations , # type: ignore
@@ -234,7 +264,7 @@ def _mmd( # pylint: disable=too-many-locals
234264 kernel = kernel ,
235265 )
236266 mmd = (
237- + k_xx_sum / ( x_num_samples * ( x_num_samples - 1 ))
267+ + expected_k_xx
238268 + k_yy_sum / (y_num_samples * (y_num_samples - 1 ))
239269 - 2 * k_xy_sum / (x_num_samples * y_num_samples ) # type: ignore
240270 )
0 commit comments