Skip to content

Commit efc5460

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #275 from IFCA/feature-non-zero-pvalues-permutationtest
Add flag to avoid zero p-values in PermutationTestDistanceBased callback
2 parents 692e2d9 + ce8f48a commit efc5460

File tree

2 files changed

+74
-6
lines changed

2 files changed

+74
-6
lines changed

frouros/callbacks/batch/permutation_test.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ class PermutationTestDistanceBased(BaseCallbackBatch):
1616
:type num_permutations: int
1717
:param num_jobs: number of jobs, defaults to -1
1818
:type num_jobs: int
19+
:param conservative: conservative flag, defaults to False. If False, the p-value can be zero `(#permuted_statistics >= observed_statistic) / num_permutations`. If True, uses the conservative approach to avoid zero p-value `((#permuted_statistics >= observed_statistic) + 1) / (num_permutations + 1)`.
20+
:type conservative: bool
21+
:param random_state: random state, defaults to None
22+
:type random_state: Optional[int]
1923
:param verbose: verbose flag, defaults to False
2024
:type verbose: bool
2125
:param name: name value, defaults to None. If None, the name will be set to `PermutationTestDistanceBased`.
@@ -49,15 +53,17 @@ def __init__( # noqa: D107
4953
self,
5054
num_permutations: int,
5155
num_jobs: int = -1,
56+
conservative: bool = False,
57+
random_state: Optional[int] = None,
5258
verbose: bool = False,
5359
name: Optional[str] = None,
54-
**kwargs,
5560
) -> None:
5661
super().__init__(name=name)
5762
self.num_permutations = num_permutations
5863
self.num_jobs = num_jobs
64+
self.conservative = conservative
65+
self.random_state = random_state
5966
self.verbose = verbose
60-
self.permutation_kwargs = kwargs
6167

6268
@property
6369
def num_permutations(self) -> int:
@@ -101,6 +107,27 @@ def num_jobs(self, value: int) -> None:
101107
raise ValueError("value must be greater than 0 or -1.")
102108
self._num_jobs = multiprocessing.cpu_count() if value == -1 else value
103109

110+
@property
111+
def conservative(self) -> bool:
112+
"""Conservative (avoid zero p-value) flag property.
113+
114+
:return: conservative flag
115+
:rtype: bool
116+
"""
117+
return self._conservative
118+
119+
@conservative.setter
120+
def conservative(self, value: bool) -> None:
121+
"""Conservative (avoid zero p-value) flag setter.
122+
123+
:param value: value to be set
124+
:type value: bool
125+
:raises TypeError: Type error exception
126+
"""
127+
if not isinstance(value, bool):
128+
raise TypeError("value must of type bool.")
129+
self._conservative = value
130+
104131
@property
105132
def verbose(self) -> bool:
106133
"""Verbose flag property.
@@ -131,7 +158,8 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
131158
observed_statistic: float,
132159
num_permutations: int,
133160
num_jobs: int,
134-
random_state: int,
161+
conservative: bool,
162+
random_state: Optional[int],
135163
verbose: bool,
136164
) -> Tuple[List[float], float]:
137165
permuted_statistic = permutation(
@@ -145,7 +173,12 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
145173
verbose=verbose,
146174
)
147175
permuted_statistic = np.array(permuted_statistic)
148-
p_value = (permuted_statistic >= observed_statistic).mean() # type: ignore
176+
p_value = (
177+
((permuted_statistic >= observed_statistic).sum() + 1) # type: ignore
178+
/ (num_permutations + 1)
179+
if conservative
180+
else (permuted_statistic >= observed_statistic).mean() # type: ignore
181+
)
149182
return permuted_statistic, p_value
150183

151184
def on_compare_end(
@@ -172,8 +205,9 @@ def on_compare_end(
172205
observed_statistic=observed_statistic,
173206
num_permutations=self.num_permutations,
174207
num_jobs=self.num_jobs,
208+
conservative=self.conservative,
209+
random_state=self.random_state,
175210
verbose=self.verbose,
176-
**self.permutation_kwargs,
177211
)
178212
self.logs.update(
179213
{

frouros/tests/integration/test_callback.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_batch_permutation_test_data_univariate_different_distribution(
6565
expected_distance: float,
6666
expected_p_value: float,
6767
) -> None:
68-
"""Test batch permutation test on data callback.
68+
"""Test batch permutation test on data drift callback.
6969
7070
:param X_ref_univariate: reference univariate data
7171
:type X_ref_univariate: numpy.ndarray
@@ -101,6 +101,40 @@ def test_batch_permutation_test_data_univariate_different_distribution(
101101
)
102102

103103

104+
def test_batch_permutation_test_conservative(
105+
X_ref_univariate: np.ndarray, # noqa: N803
106+
X_test_univariate: np.ndarray,
107+
) -> None:
108+
"""Test batch permutation test on data drift callback using conservative flag.
109+
110+
:param X_ref_univariate: reference univariate data
111+
:type X_ref_univariate: numpy.ndarray
112+
:param X_test_univariate: test univariate data
113+
:type X_test_univariate: numpy.ndarray
114+
"""
115+
np.random.seed(seed=31)
116+
117+
permutation_test_name = "permutation_test"
118+
detector = MMD( # type: ignore
119+
callbacks=[
120+
PermutationTestDistanceBased(
121+
num_permutations=100,
122+
conservative=True,
123+
random_state=31,
124+
num_jobs=-1,
125+
name=permutation_test_name,
126+
)
127+
]
128+
)
129+
_ = detector.fit(X=X_ref_univariate)
130+
_, callback_logs = detector.compare(X=X_test_univariate)
131+
132+
assert np.isclose(
133+
callback_logs[permutation_test_name]["p_value"],
134+
0.00990099,
135+
)
136+
137+
104138
@pytest.mark.parametrize(
105139
"detector_class",
106140
[AndersonDarlingTest, CVMTest, KSTest, MannWhitneyUTest, WelchTTest],

0 commit comments

Comments
 (0)