Skip to content

Commit fb6fd4f

Browse files
Add PermutationTestDistanceBased API example
1 parent 84f4704 commit fb6fd4f

File tree

1 file changed

+37
-13
lines changed

1 file changed

+37
-13
lines changed

frouros/callbacks/batch/permutation_test.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,49 @@
1010

1111

1212
class PermutationTestDistanceBased(BaseCallbackBatch):
13-
"""Permutation test on distance based batch callback class."""
14-
15-
def __init__(
13+
"""Permutation test callback class that can be applied to :mod:`data_drift.batch.distance_based <frouros.detectors.data_drift.batch.distance_based>` detectors.
14+
15+
:param num_permutations: number of permutations to obtain the p-value
16+
:type num_permutations: int
17+
:param num_jobs: number of jobs, defaults to -1
18+
:type num_jobs: int
19+
:param verbose: verbose flag, defaults to False
20+
:type verbose: bool
21+
:param name: name value, defaults to None. If None, the name will be set to `PermutationTestDistanceBased`.
22+
:type name: Optional[str]
23+
24+
:Note:
25+
Callbacks logs are updated with the following variables:
26+
27+
- `observed_statistic`: observed statistic obtained from the distance-based detector. Same distance value returned by the `compare` method
28+
- `permutation_statistic`: list of statistics obtained from the permutations
29+
- `p_value`: p-value obtained from the permutation test
30+
31+
:Example:
32+
33+
>>> from frouros.callbacks import PermutationTestDistanceBased
34+
>>> from frouros.detectors.data_drift import MMD
35+
>>> import numpy as np
36+
>>> np.random.seed(seed=31)
37+
>>> X = np.random.multivariate_normal(mean=[1, 1], cov=[[2, 0], [0, 2]], size=100)
38+
>>> Y = np.random.multivariate_normal(mean=[0, 0], cov=[[2, 1], [1, 2]], size=100)
39+
>>> detector = MMD(callbacks=PermutationTestDistanceBased(num_permutations=1000, random_state=31))
40+
>>> _ = detector.fit(X=X)
41+
>>> distance, callbacks_log = detector.compare(X=Y)
42+
>>> distance
43+
DistanceResult(distance=0.05643613752975596)
44+
>>> callbacks_log["PermutationTestDistanceBased"]["p_value"]
45+
0.0
46+
""" # noqa: E501 # pylint: disable=line-too-long
47+
48+
def __init__( # noqa: D107
1649
self,
1750
num_permutations: int,
1851
num_jobs: int = -1,
19-
name: Optional[str] = None,
2052
verbose: bool = False,
53+
name: Optional[str] = None,
2154
**kwargs,
2255
) -> None:
23-
"""Init method.
24-
25-
:param num_permutations: number of permutations
26-
:type num_permutations: int
27-
:param num_jobs: number of jobs
28-
:type num_jobs: int
29-
:param name: name to be use
30-
:type name: Optional[str]
31-
"""
3256
super().__init__(name=name)
3357
self.num_permutations = num_permutations
3458
self.num_jobs = num_jobs

0 commit comments

Comments
 (0)