99from frouros .detectors .data_drift import MMD # type: ignore
1010from frouros .utils .kernels import rbf_kernel
1111
12+ RANDOM_SEED = 31
13+ DEFAULT_SIGMA = 0.5
14+
1215
1316@pytest .mark .parametrize (
1417 "distribution_p, distribution_q, expected_distance" ,
@@ -35,12 +38,15 @@ def test_mmd_batch_univariate(
3538 :param expected_distance: expected distance value
3639 :type expected_distance: float
3740 """
38- np .random .seed (seed = 31 )
41+ np .random .seed (seed = RANDOM_SEED )
3942 X_ref = np .random .normal (* distribution_p ) # noqa: N806
4043 X_test = np .random .normal (* distribution_q ) # noqa: N806
4144
4245 detector = MMD (
43- kernel = partial (rbf_kernel , sigma = 0.5 ),
46+ kernel = partial (
47+ rbf_kernel ,
48+ sigma = DEFAULT_SIGMA ,
49+ ),
4450 )
4551 _ = detector .fit (X = X_ref )
4652
@@ -77,11 +83,14 @@ def test_mmd_batch_precomputed_expected_k_xx(
7783 :param chunk_size: chunk size
7884 :type chunk_size: Optional[int]
7985 """
80- np .random .seed (seed = 31 )
86+ np .random .seed (seed = RANDOM_SEED )
8187 X_ref = np .random .normal (* distribution_p ) # noqa: N806
8288 X_test = np .random .normal (* distribution_q ) # noqa: N806
8389
84- kernel = partial (rbf_kernel , sigma = 0.5 )
90+ kernel = partial (
91+ rbf_kernel ,
92+ sigma = DEFAULT_SIGMA ,
93+ )
8594
8695 detector = MMD (
8796 kernel = kernel ,
@@ -101,3 +110,53 @@ def test_mmd_batch_precomputed_expected_k_xx(
101110 )
102111
103112 assert np .isclose (precomputed_distance , scratch_distance )
113+
114+
115+ @pytest .mark .parametrize (
116+ "distribution_p, distribution_q, chunk_size" ,
117+ [
118+ ((0 , 1 , size ), (2 , 1 , size ), chunk_size )
119+ for size in [10 , 100 ]
120+ for chunk_size in list (range (1 , 11 ))
121+ ],
122+ )
123+ def test_mmd_chunk_size_equivalence (
124+ distribution_p : Tuple [float , float , int ],
125+ distribution_q : Tuple [float , float , int ],
126+ chunk_size : int ,
127+ ) -> None :
128+ """Test MMD with chunk_size=None vs specific chunk_size.
129+
130+ :param distribution_p: mean, std and size of samples from distribution p
131+ :type distribution_p: Tuple[float, float, int]
132+ :param distribution_q: mean, std and size of samples from distribution q
133+ :type distribution_q: Tuple[float, float, int]
134+ :param chunk_size: specific chunk size to compare with None
135+ :type chunk_size: int
136+ """
137+ np .random .seed (seed = RANDOM_SEED )
138+ X_ref = np .random .normal (* distribution_p ) # noqa: N806
139+ X_test = np .random .normal (* distribution_q ) # noqa: N806
140+
141+ kernel = partial (
142+ rbf_kernel ,
143+ sigma = DEFAULT_SIGMA ,
144+ )
145+
146+ # Detector with chunk_size=None
147+ detector_none = MMD (
148+ kernel = kernel ,
149+ chunk_size = None ,
150+ )
151+ _ = detector_none .fit (X = X_ref )
152+ result_none = detector_none .compare (X = X_test )[0 ].distance
153+
154+ # Detector with specific chunk_size
155+ detector_chunk = MMD (
156+ kernel = kernel ,
157+ chunk_size = chunk_size ,
158+ )
159+ _ = detector_chunk .fit (X = X_ref )
160+ result_chunk = detector_chunk .compare (X = X_test )[0 ].distance
161+
162+ assert np .isclose (result_none , result_chunk )
0 commit comments