Skip to content

Commit 87ae546

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #203 from IFCA/fix-permutation-tests
Fix permutation tests calculating p-values with z-score
2 parents 8dc64af + 7f11d96 commit 87ae546

File tree

4 files changed

+104
-74
lines changed

4 files changed

+104
-74
lines changed

docs/source/examples/data_drift/MMD_advance.ipynb

Lines changed: 63 additions & 63 deletions
Large diffs are not rendered by default.

frouros/callbacks/batch/permutation_test.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,10 @@
44
from typing import Any, Callable, Dict, List, Optional, Tuple
55

66
import numpy as np # type: ignore
7+
from scipy.stats import norm # type: ignore
78

89
from frouros.callbacks.batch.base import BatchCallback
9-
from frouros.utils.stats import permutation
10+
from frouros.utils.stats import permutation, z_score
1011

1112

1213
class PermutationTestOnBatchData(BatchCallback):
@@ -120,7 +121,14 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
120121
random_state=random_state,
121122
verbose=verbose,
122123
)
123-
p_value = (observed_statistic < permuted_statistic).mean() # type: ignore
124+
permuted_statistic = np.array(permuted_statistic)
125+
# Use z-score to calculate p-value
126+
observed_z_score = z_score(
127+
value=observed_statistic,
128+
mean=permuted_statistic.mean(), # type: ignore
129+
std=permuted_statistic.std(), # type: ignore
130+
)
131+
p_value = norm.sf(np.abs(observed_z_score)) * 2
124132
return permuted_statistic, p_value
125133

126134
def on_compare_end(self, **kwargs) -> None:

frouros/tests/integration/test_callback.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -54,13 +54,13 @@
5454
"detector_class, expected_distance, expected_p_value",
5555
[
5656
(BhattacharyyaDistance, 0.55516059, 0.0),
57-
(EMD, 3.85346006, 0.0),
58-
(HellingerDistance, 0.74509099, 0.0),
59-
(HINormalizedComplement, 0.78, 0.0),
60-
(JS, 0.67010107, 0.0),
61-
(KL, np.inf, 0.0),
62-
(MMD, 0.69509004, 0.0),
63-
(PSI, 461.20379435, 0.0),
57+
(EMD, 3.85346006, 9.21632493e-101),
58+
(HellingerDistance, 0.74509099, 3.13808126e-50),
59+
(HINormalizedComplement, 0.78, 1.31340683e-55),
60+
(JS, 0.67010107, 2.30485343e-63),
61+
(KL, np.inf, np.nan),
62+
(MMD, 0.69509004, 2.53277069e-137),
63+
(PSI, 461.20379435, 4.45088795e-238),
6464
],
6565
)
6666
def test_batch_permutation_test_data_univariate_different_distribution(
@@ -100,7 +100,11 @@ def test_batch_permutation_test_data_univariate_different_distribution(
100100
distance, callback_logs = detector.compare(X=X_test_univariate)
101101

102102
assert np.isclose(distance, expected_distance)
103-
assert np.isclose(callback_logs[permutation_test_name]["p_value"], expected_p_value)
103+
assert np.isclose(
104+
callback_logs[permutation_test_name]["p_value"],
105+
expected_p_value,
106+
equal_nan=True,
107+
)
104108

105109

106110
@pytest.mark.parametrize(

frouros/utils/stats.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -262,5 +262,23 @@ def permutation( # pylint: disable=too-many-arguments,too-many-locals
262262
iterable=tqdm(permuted_data) if verbose else permuted_data,
263263
).get()
264264

265-
# FIXME: explore if abs must be used in permuted_statistic # pylint: disable=fixme
266265
return permuted_statistics
266+
267+
268+
def z_score(
    value: np.ndarray,
    mean: float,
    std: float,
) -> np.ndarray:
    """Z-score method.

    Standardize ``value`` by subtracting ``mean`` and dividing by ``std``.
    The computation is carried out with NumPy so that degenerate inputs
    behave the same way for Python floats and NumPy scalars/arrays:
    a zero ``std`` yields ``inf``/``-inf`` (or ``nan`` when the numerator
    is also zero) and an ``inf - inf`` numerator yields ``nan``, instead
    of raising ``ZeroDivisionError`` or emitting a ``RuntimeWarning``.

    :param value: value to use to compute the z-score
    :type value: np.ndarray
    :param mean: mean value
    :type mean: float
    :param std: standard deviation value
    :type std: float
    :return: z-score
    :rtype: np.ndarray
    """
    # Silence divide-by-zero / invalid-operation warnings: callers rely on
    # the IEEE result (inf or nan) being propagated rather than on a warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.divide(np.subtract(value, mean), std)

0 commit comments

Comments
 (0)