11"""MMD (Maximum Mean Discrepancy) module."""
22
3- from typing import Callable , Optional , List , Union
3+ import itertools
4+ import math
5+ from typing import Callable , Iterator , Optional , List , Union
46
57import numpy as np # type: ignore
68from scipy .spatial .distance import cdist # type: ignore
9+ import tqdm # type: ignore
710
811from frouros .callbacks import Callback
912from frouros .detectors .data_drift .base import MultivariateData
@@ -43,12 +46,15 @@ class MMD(DistanceBasedBase):
4346 def __init__ (
4447 self ,
4548 kernel : Callable = rbf_kernel ,
49+ chunk_size : Optional [int ] = None ,
4650 callbacks : Optional [Union [Callback , List [Callback ]]] = None ,
4751 ) -> None :
4852 """Init method.
4953
5054 :param kernel: kernel function to use
5155 :type kernel: Callable
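56+ :param chunk_size: number of samples per chunk used when evaluating the kernel (None processes all samples as a single chunk)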
57+ :type chunk_size: Optional[int]
5258 :param callbacks: callbacks
5359 :type callbacks: Optional[Union[Callback, List[Callback]]]
5460 """
@@ -61,13 +67,42 @@ def __init__(
6167 callbacks = callbacks ,
6268 )
6369 self .kernel = kernel
70+ self .chunk_size = chunk_size
71+ self ._chunk_size_x = None
72+ self ._expected_k_x = None
73+ self ._X_num_samples = None
74+
@property
def chunk_size(self) -> Optional[int]:
    """Chunk size property.

    :return: chunk size to use (None means the data is processed
        as a single chunk)
    :rtype: Optional[int]
    """
    return self._chunk_size

@chunk_size.setter
def chunk_size(self, value: Optional[int]) -> None:
    """Chunk size method setter.

    :param value: value to be set
    :type value: Optional[int]
    :raises TypeError: if value is neither None nor an int
        (bool values are rejected as well)
    :raises ValueError: if value is an int lower than or equal to 0
    """
    if value is not None:
        # bool is a subclass of int, so it has to be excluded explicitly;
        # otherwise True would be silently accepted as a chunk size of 1.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError("chunk_size must be of type int or None.")
        if value <= 0:
            raise ValueError("chunk_size must be greater than 0 or None.")
    self._chunk_size = value
6499
@property
def kernel(self) -> Callable:
    """Kernel property.

    :return: kernel function currently stored on the instance
    :rtype: Callable
    """
    return self._kernel
73108
@@ -80,38 +115,147 @@ def kernel(self, value: Callable) -> None:
80115 :raises TypeError: Type error exception
81116 """
82117 if not isinstance (value , Callable ): # type: ignore
83- raise TypeError ("value must be of type Callable." )
118+ raise TypeError ("kernel must be of type Callable." )
84119 self ._kernel = value
85120
def _distance_measure(
    self,
    X_ref: np.ndarray,  # noqa: N803
    X: np.ndarray,  # noqa: N803
    **kwargs,
) -> DistanceResult:
    """Compute the MMD between reference and test data.

    :param X_ref: reference data
    :type X_ref: numpy.ndarray
    :param X: data to compare against the reference
    :type X: numpy.ndarray
    :param kwargs: extra keyword arguments forwarded unchanged to ``_mmd``
    :return: MMD value wrapped in a DistanceResult
    :rtype: DistanceResult
    """
    mmd_value = self._mmd(X=X_ref, Y=X, kernel=self.kernel, **kwargs)
    return DistanceResult(distance=mmd_value)
94130
def _fit(
    self,
    X: np.ndarray,  # noqa: N803
    **kwargs,
) -> None:
    """Fit the reference data and cache its kernel term.

    Delegates to the parent ``_fit`` and then stores the number of
    reference samples, the chunk size used for the reference data and
    the summed kernel over all pairs of reference chunks divided by
    ``n * (n - 1)``.

    :param X: reference data
    :type X: numpy.ndarray
    :param kwargs: optional settings; ``verbose=True`` shows a tqdm
        progress bar while the chunk pairs are evaluated
    """
    super()._fit(X=X)
    # Add dimension only for the kernel calculation (if dim == 1)
    if X.ndim == 1:
        X = np.expand_dims(X, axis=1)  # noqa: N806

    num_samples = len(X)
    self._X_num_samples = num_samples  # type: ignore # noqa: N806

    # Without an explicit chunk size the whole array is a single chunk.
    chunk_size_x = num_samples if self.chunk_size is None else self.chunk_size
    self._chunk_size_x = chunk_size_x  # type: ignore

    chunk_pairs = itertools.product(
        self._get_chunks(data=X, chunk_size=chunk_size_x),
        repeat=2,
    )
    if kwargs.get("verbose", False):
        total_pairs = math.ceil(num_samples / chunk_size_x) ** 2
        chunk_pairs = tqdm.tqdm(chunk_pairs, total=total_pairs)

    partial_sums = [self.kernel(*pair).sum() for pair in chunk_pairs]
    k_x_sum = np.array(partial_sums).sum()
    self._expected_k_x = k_x_sum / (num_samples * (num_samples - 1))
176+
@staticmethod
def _get_chunks(data: np.ndarray, chunk_size: int) -> Iterator:
    """Lazily split data into consecutive slices of chunk_size items.

    The last slice is shorter when ``len(data)`` is not a multiple of
    ``chunk_size``.

    :param data: data to be split
    :type data: numpy.ndarray
    :param chunk_size: maximum number of items per slice
    :type chunk_size: int
    :return: iterator over the slices
    :rtype: Iterator
    """
    starts = range(0, len(data), chunk_size)
    return (data[start : start + chunk_size] for start in starts)  # noqa: E203
184+
def _mmd(  # pylint: disable=too-many-locals
    self,
    X: np.ndarray,  # noqa: N803
    Y: np.ndarray,
    *,
    kernel: Callable,
    **kwargs,
) -> float:  # noqa: N803
    """Compute the MMD value between X and Y using chunked kernel sums.

    The X-only term is not recomputed here: it is read from
    ``self._expected_k_x``, and ``self._chunk_size_x`` and
    ``self._X_num_samples`` are read as well, so ``_fit`` must have
    been run beforehand.

    :param X: reference data
    :type X: numpy.ndarray
    :param Y: data to compare against the reference
    :type Y: numpy.ndarray
    :param kernel: kernel function called positionally with two chunks
    :type kernel: Callable
    :param kwargs: optional settings; ``verbose=True`` shows tqdm
        progress bars while the chunk pairs are evaluated
    :return: MMD value
    :rtype: float
    """
    # Only check for X dimension (X == Y dim comparison has been already made)
    if X.ndim == 1:
        X = np.expand_dims(X, axis=1)  # noqa: N806
        Y = np.expand_dims(Y, axis=1)  # noqa: N806

    verbose = kwargs.get("verbose", False)

    def _chunked_kernel_sum(chunk_pairs, total):
        # Sum the kernel over every pair of chunks, optionally with progress.
        if verbose:
            chunk_pairs = tqdm.tqdm(chunk_pairs, total=total)
        return np.array([kernel(*pair).sum() for pair in chunk_pairs]).sum()

    X_chunks = self._get_chunks(  # noqa: N806
        data=X,
        chunk_size=self._chunk_size_x,  # type: ignore
    )
    Y_num_samples = len(Y)  # noqa: N806
    # Without an explicit chunk size the whole array is a single chunk.
    chunk_size_y = Y_num_samples if self.chunk_size is None else self.chunk_size
    # Y chunks are needed twice (Y-Y and X-Y pairs), hence the tee.
    Y_chunks, Y_chunks_copy = itertools.tee(  # noqa: N806
        self._get_chunks(
            data=Y,
            chunk_size=chunk_size_y,  # type: ignore
        ),
        2,
    )
    Y_chunks_combinations = itertools.product(  # noqa: N806
        Y_chunks,
        repeat=2,
    )
    XY_chunks_combinations = itertools.product(  # noqa: N806
        X_chunks,
        Y_chunks_copy,
    )

    total_y = total_xy = None
    if verbose:
        # Use the resolved chunk_size_y: self.chunk_size may be None, which
        # would make the division fail.
        num_chunks_y = math.ceil(Y_num_samples / chunk_size_y)  # type: ignore
        total_y = num_chunks_y**2
        total_xy = (
            math.ceil(len(X) / self._chunk_size_x) * num_chunks_y  # type: ignore
        )

    sum_y = _chunked_kernel_sum(Y_chunks_combinations, total_y)
    sum_xy = _chunked_kernel_sum(XY_chunks_combinations, total_xy)

    mmd = (
        self._expected_k_x
        + sum_y / (Y_num_samples * (Y_num_samples - 1))
        - 2 * sum_xy / (self._X_num_samples * Y_num_samples)  # type: ignore
    )
    return mmd
0 commit comments