Skip to content

Commit 6b19761

Browse files
Fix p-values workaround computation
1 parent f04d4f8 commit 6b19761

File tree

2 files changed

+9
-17
lines changed

2 files changed

+9
-17
lines changed

frouros/callbacks/batch/permutation_test.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from typing import Any, Callable, Dict, List, Optional, Tuple
55

66
import numpy as np # type: ignore
7-
from scipy.stats import norm # type: ignore
87

98
from frouros.callbacks.batch.base import BaseCallbackBatch
10-
from frouros.utils.stats import permutation, z_score
9+
from frouros.utils.stats import permutation
1110

1211

1312
class PermutationTestDistanceBased(BaseCallbackBatch):
@@ -122,13 +121,7 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
122121
verbose=verbose,
123122
)
124123
permuted_statistic = np.array(permuted_statistic)
125-
# Use z-score to calculate p-value
126-
observed_z_score = z_score(
127-
value=observed_statistic,
128-
mean=permuted_statistic.mean(), # type: ignore
129-
std=permuted_statistic.std(), # type: ignore
130-
)
131-
p_value = norm.sf(np.abs(observed_z_score)) * 2
124+
p_value = (permuted_statistic >= observed_statistic).mean() # type: ignore
132125
return permuted_statistic, p_value
133126

134127
def on_compare_end(self, **kwargs) -> None:

frouros/tests/integration/test_callback.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@
5454
"detector_class, expected_distance, expected_p_value",
5555
[
5656
(BhattacharyyaDistance, 0.55516059, 0.0),
57-
(EMD, 3.85346006, 9.21632493e-101),
58-
(HellingerDistance, 0.74509099, 3.13808126e-50),
59-
(HINormalizedComplement, 0.78, 1.31340683e-55),
60-
(JS, 0.67010107, 2.30485343e-63),
61-
(KL, np.inf, np.nan),
62-
(MMD, 0.69509004, 2.53277069e-137),
63-
(PSI, 461.20379435, 4.45088795e-238),
57+
(EMD, 3.85346006, 0.0),
58+
(HellingerDistance, 0.74509099, 0.0),
59+
(HINormalizedComplement, 0.78, 0.0),
60+
(JS, 0.67010107, 0.0),
61+
(KL, np.inf, 0.06),
62+
(MMD, 0.69509004, 0.0),
63+
(PSI, 461.20379435, 0.0),
6464
],
6565
)
6666
def test_batch_permutation_test_data_univariate_different_distribution(
@@ -103,7 +103,6 @@ def test_batch_permutation_test_data_univariate_different_distribution(
103103
assert np.isclose(
104104
callback_logs[permutation_test_name]["p_value"],
105105
expected_p_value,
106-
equal_nan=True,
107106
)
108107

109108

0 commit comments

Comments
 (0)