@@ -51,10 +51,6 @@ def __init__(
5151 )
5252 self .kernel = kernel
5353 self .chunk_size = chunk_size
54- self ._chunk_size_x = None
55- self .X_chunks_combinations = None
56- self .X_num_samples = None
57- self .expected_k_xx = None
5854
5955 @property
6056 def chunk_size (self ) -> Optional [int ]:
@@ -108,55 +104,16 @@ def _distance_measure(
108104 X : np .ndarray , # noqa: N803
109105 ** kwargs ,
110106 ) -> DistanceResult :
111- mmd = self ._mmd (X = X_ref , Y = X , kernel = self .kernel , ** kwargs )
107+ mmd = self ._mmd (
108+ X = X_ref ,
109+ Y = X ,
110+ kernel = self .kernel ,
111+ chunk_size = self .chunk_size ,
112+ ** kwargs ,
113+ )
112114 distance_test = DistanceResult (distance = mmd )
113115 return distance_test
114116
115- def _fit (
116- self ,
117- X : np .ndarray , # noqa: N803
118- ** kwargs ,
119- ) -> None :
120- super ()._fit (X = X )
121- # Add dimension only for the kernel calculation (if dim == 1)
122- if X .ndim == 1 :
123- X = np .expand_dims (X , axis = 1 ) # noqa: N806
124- self .X_num_samples = len (self .X_ref ) # type: ignore # noqa: N806
125-
126- self ._chunk_size_x = (
127- self .X_num_samples
128- if self .chunk_size is None
129- else self .chunk_size # type: ignore
130- )
131-
132- X_chunks = self ._get_chunks ( # noqa: N806
133- data = X ,
134- chunk_size = self ._chunk_size_x , # type: ignore
135- )
136- xx_chunks_combinations = itertools .product (X_chunks , repeat = 2 ) # noqa: N806
137-
138- if kwargs .get ("verbose" , False ):
139- num_chunks = (
140- math .ceil (self .X_num_samples / self ._chunk_size_x ) ** 2 # type: ignore
141- )
142- xx_chunks_combinations = tqdm .tqdm (
143- xx_chunks_combinations ,
144- total = num_chunks ,
145- )
146-
147- k_xx_sum = (
148- self ._compute_kernel (
149- chunk_combinations = xx_chunks_combinations , # type: ignore
150- kernel = self .kernel ,
151- )
152- # Remove diagonal (j!=i case)
153- - self .X_num_samples # type: ignore
154- )
155-
156- self .expected_k_xx = k_xx_sum / ( # type: ignore
157- self .X_num_samples * (self .X_num_samples - 1 ) # type: ignore
158- )
159-
160117 @staticmethod
161118 def _compute_kernel (chunk_combinations : Generator , kernel : Callable ) -> float :
162119 k_sum = np .array ([kernel (* chunk ).sum () for chunk in chunk_combinations ]).sum ()
@@ -170,8 +127,8 @@ def _get_chunks(data: np.ndarray, chunk_size: int) -> Generator:
170127 )
171128 return chunks
172129
130+ @staticmethod
173131 def _mmd ( # pylint: disable=too-many-locals
174- self ,
175132 X : np .ndarray , # noqa: N803
176133 Y : np .ndarray ,
177134 * ,
@@ -183,33 +140,56 @@ def _mmd( # pylint: disable=too-many-locals
183140 X = np .expand_dims (X , axis = 1 ) # noqa: N806
184141 Y = np .expand_dims (Y , axis = 1 ) # noqa: N806
185142
186- X_chunks = self ._get_chunks ( # noqa: N806
187- data = X ,
188- chunk_size = self ._chunk_size_x , # type: ignore
143+ x_num_samples = len (X ) # noqa: N806
144+ chunk_size_x = (
145+ kwargs ["chunk_size" ]
146+ if "chunk_size" in kwargs and kwargs ["chunk_size" ] is not None
147+ else x_num_samples
148+ )
149+ x_chunks , x_chunks_copy = itertools .tee ( # noqa: N806
150+ MMD ._get_chunks (
151+ data = X ,
152+ chunk_size = chunk_size_x , # type: ignore
153+ ),
154+ 2 ,
189155 )
190156 y_num_samples = len (Y ) # noqa: N806
191- chunk_size_y = y_num_samples if self .chunk_size is None else self .chunk_size
157+ chunk_size_y = (
158+ kwargs ["chunk_size" ]
159+ if "chunk_size" in kwargs and kwargs ["chunk_size" ] is not None
160+ else y_num_samples
161+ )
192162 y_chunks , y_chunks_copy = itertools .tee ( # noqa: N806
193- self ._get_chunks (
163+ MMD ._get_chunks (
194164 data = Y ,
195165 chunk_size = chunk_size_y , # type: ignore
196166 ),
197167 2 ,
198168 )
169+ x_chunks_combinations = itertools .product ( # noqa: N806
170+ x_chunks ,
171+ repeat = 2 ,
172+ )
199173 y_chunks_combinations = itertools .product ( # noqa: N806
200174 y_chunks ,
201175 repeat = 2 ,
202176 )
203177 xy_chunks_combinations = itertools .product ( # noqa: N806
204- X_chunks ,
178+ x_chunks_copy ,
205179 y_chunks_copy ,
206180 )
207181
208182 if kwargs .get ("verbose" , False ):
183+ num_chunks_x = math .ceil (x_num_samples / chunk_size_x ) # type: ignore
209184 num_chunks_y = math .ceil (y_num_samples / chunk_size_y ) # type: ignore
185+ num_chunks_x_combinations = num_chunks_x ** 2
210186 num_chunks_y_combinations = num_chunks_y ** 2
211187 num_chunks_xy = (
212- math .ceil (len (X ) / self ._chunk_size_x ) * num_chunks_y # type: ignore
188+ math .ceil (len (X ) / chunk_size_x ) * num_chunks_y # type: ignore
189+ )
190+ x_chunks_combinations = tqdm .tqdm (
191+ x_chunks_combinations ,
192+ total = num_chunks_x_combinations ,
213193 )
214194 y_chunks_combinations = tqdm .tqdm (
215195 y_chunks_combinations ,
@@ -220,21 +200,29 @@ def _mmd( # pylint: disable=too-many-locals
220200 total = num_chunks_xy ,
221201 )
222202
203+ k_xx_sum = (
204+ MMD ._compute_kernel (
205+ chunk_combinations = x_chunks_combinations , # type: ignore
206+ kernel = kernel ,
207+ )
208+ # Remove diagonal (j!=i case)
209+ - x_num_samples # type: ignore
210+ )
223211 k_yy_sum = (
224- self ._compute_kernel (
212+ MMD ._compute_kernel (
225213 chunk_combinations = y_chunks_combinations , # type: ignore
226214 kernel = kernel ,
227215 )
228216 # Remove diagonal (j!=i case)
229217 - y_num_samples # type: ignore
230218 )
231- k_xy_sum = self ._compute_kernel (
219+ k_xy_sum = MMD ._compute_kernel (
232220 chunk_combinations = xy_chunks_combinations , # type: ignore
233221 kernel = kernel ,
234222 )
235223 mmd = (
236- self . expected_k_xx # type: ignore
224+ + k_xx_sum / ( x_num_samples * ( x_num_samples - 1 ))
237225 + k_yy_sum / (y_num_samples * (y_num_samples - 1 ))
238- - 2 * k_xy_sum / (self . X_num_samples * y_num_samples ) # type: ignore
226+ - 2 * k_xy_sum / (x_num_samples * y_num_samples ) # type: ignore
239227 )
240228 return mmd
0 commit comments