|
10 | 10 |
|
11 | 11 |
|
12 | 12 | class PermutationTestDistanceBased(BaseCallbackBatch): |
13 | | - """Permutation test on distance based batch callback class.""" |
14 | | - |
15 | | - def __init__( |
| 13 | + """Permutation test callback class that can be applied to :mod:`data_drift.batch.distance_based <frouros.detectors.data_drift.batch.distance_based>` detectors. |
| 14 | +
|
| 15 | + :param num_permutations: number of permutations to obtain the p-value |
| 16 | + :type num_permutations: int |
| 17 | + :param num_jobs: number of jobs, defaults to -1 |
| 18 | + :type num_jobs: int |
| 19 | + :param verbose: verbose flag, defaults to False |
| 20 | + :type verbose: bool |
| 21 | + :param name: name value, defaults to None. If None, the name will be set to `PermutationTestDistanceBased`. |
| 22 | + :type name: Optional[str] |
| 23 | + |
| 24 | + :Note: |
| 25 | + Callbacks logs are updated with the following variables: |
| 26 | + |
| 27 | + - `observed_statistic`: observed statistic obtained from the distance-based detector. Same distance value returned by the `compare` method |
| 28 | + - `permutation_statistic`: list of statistics obtained from the permutations |
| 29 | + - `p_value`: p-value obtained from the permutation test |
| 30 | +
|
| 31 | + :Example: |
| 32 | +
|
| 33 | + >>> from frouros.callbacks import PermutationTestDistanceBased |
| 34 | + >>> from frouros.detectors.data_drift import MMD |
| 35 | + >>> import numpy as np |
| 36 | + >>> np.random.seed(seed=31) |
| 37 | + >>> X = np.random.multivariate_normal(mean=[1, 1], cov=[[2, 0], [0, 2]], size=100) |
| 38 | + >>> Y = np.random.multivariate_normal(mean=[0, 0], cov=[[2, 1], [1, 2]], size=100) |
| 39 | + >>> detector = MMD(callbacks=PermutationTestDistanceBased(num_permutations=1000, random_state=31)) |
| 40 | + >>> _ = detector.fit(X=X) |
| 41 | + >>> distance, callbacks_log = detector.compare(X=Y) |
| 42 | + >>> distance |
| 43 | + DistanceResult(distance=0.05643613752975596) |
| 44 | + >>> callbacks_log["PermutationTestDistanceBased"]["p_value"] |
| 45 | + 0.0 |
| 46 | + """ # noqa: E501 # pylint: disable=line-too-long |
| 47 | + |
| 48 | + def __init__( # noqa: D107 |
16 | 49 | self, |
17 | 50 | num_permutations: int, |
18 | 51 | num_jobs: int = -1, |
19 | | - name: Optional[str] = None, |
20 | 52 | verbose: bool = False, |
| 53 | + name: Optional[str] = None, |
21 | 54 | **kwargs, |
22 | 55 | ) -> None: |
23 | | - """Init method. |
24 | | -
|
25 | | - :param num_permutations: number of permutations |
26 | | - :type num_permutations: int |
27 | | - :param num_jobs: number of jobs |
28 | | - :type num_jobs: int |
29 | | - :param name: name to be use |
30 | | - :type name: Optional[str] |
31 | | - """ |
32 | 56 | super().__init__(name=name) |
33 | 57 | self.num_permutations = num_permutations |
34 | 58 | self.num_jobs = num_jobs |
|
0 commit comments