Skip to content

Commit 0f05005

Browse files
Add ResetStatisticalTest API example
1 parent fb6fd4f commit 0f05005

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

frouros/callbacks/batch/reset.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,33 @@
77

88

99
class ResetStatisticalTest(BaseCallbackBatch):
10-
"""Reset on statistical test batch callback class."""
10+
"""Reset callback class that can be applied to :mod:`data_drift.batch.statistical_test <frouros.detectors.data_drift.batch.statistical_test>` detectors.
1111
12-
def __init__(self, alpha: float, name: Optional[str] = None) -> None:
13-
"""Init method.
12+
:param alpha: significance value
13+
:type alpha: float
14+
:param name: name value, defaults to None. If None, the name will be set to `ResetStatisticalTest`.
15+
:type name: Optional[str]
1416
15-
:param alpha: significance value
16-
:type alpha: float
17-
:param name: name to be use
18-
:type name: Optional[str]
19-
"""
17+
:Example:
18+
19+
>>> from frouros.callbacks import ResetStatisticalTest
20+
>>> from frouros.detectors.data_drift import KSTest
21+
>>> import numpy as np
22+
>>> np.random.seed(seed=31)
23+
>>> X = np.random.normal(loc=0, scale=1, size=100)
24+
>>> Y = np.random.normal(loc=1, scale=1, size=100)
25+
>>> detector = KSTest(callbacks=ResetStatisticalTest(alpha=0.01))
26+
>>> _ = detector.fit(X=X)
27+
>>> detector.compare(X=Y)[0]
28+
INFO:frouros:Drift detected. Resetting detector...
29+
StatisticalResult(statistic=0.55, p_value=3.0406585087050305e-14)
30+
""" # noqa: E501 # pylint: disable=line-too-long
31+
32+
def __init__( # noqa: D107
33+
self,
34+
alpha: float,
35+
name: Optional[str] = None,
36+
) -> None:
2037
super().__init__(name=name)
2138
self.alpha = alpha
2239

frouros/tests/integration/test_callback.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
BaseSPC,
3535
)
3636
from frouros.detectors.data_drift.batch import (
37+
AndersonDarlingTest,
3738
BhattacharyyaDistance,
3839
CVMTest,
3940
EMD,
@@ -42,6 +43,7 @@
4243
JS,
4344
KL,
4445
KSTest,
46+
MannWhitneyUTest,
4547
MMD,
4648
PSI,
4749
WelchTTest,
@@ -108,7 +110,7 @@ def test_batch_permutation_test_data_univariate_different_distribution(
108110

109111
@pytest.mark.parametrize(
110112
"detector_class",
111-
[CVMTest, KSTest, WelchTTest],
113+
[AndersonDarlingTest, CVMTest, KSTest, MannWhitneyUTest, WelchTTest],
112114
)
113115
def test_batch_reset_on_statistical_test_data_drift(
114116
X_ref_univariate, # noqa: N803

0 commit comments

Comments
 (0)