Skip to content

Commit e6619b8

Browse files
Adapt existing callbacks to non using kwargs
1 parent 933a17d commit e6619b8

File tree

3 files changed

+44
-13
lines changed

3 files changed

+44
-13
lines changed

frouros/callbacks/batch/permutation_test.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,22 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
148148
p_value = (permuted_statistic >= observed_statistic).mean() # type: ignore
149149
return permuted_statistic, p_value
150150

151-
def on_compare_end(self, **kwargs) -> None:
152-
"""On compare end method."""
153-
X_ref, X_test = kwargs["X_ref"], kwargs["X_test"] # noqa: N806
154-
observed_statistic = kwargs["result"][0]
151+
def on_compare_end(
152+
self,
153+
result: Any,
154+
X_ref: np.ndarray, # noqa: N803
155+
X_test: np.ndarray,
156+
) -> None:
157+
"""On compare end method.
158+
159+
:param result: result obtained from the `compare` method
160+
:type result: Any
161+
:param X_ref: reference data
162+
:type X_ref: numpy.ndarray
163+
:param X_test: test data
164+
:type X_test: numpy.ndarray
165+
"""
166+
observed_statistic = result.distance
155167
permuted_statistics, p_value = self._calculate_p_value(
156168
X_ref=X_ref,
157169
X_test=X_test,

frouros/callbacks/batch/reset.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Reset batch callback module."""
22

3-
from typing import Optional
3+
from typing import Any, Optional
4+
5+
import numpy as np # type: ignore
46

57
from frouros.callbacks.batch.base import BaseCallbackBatch
68
from frouros.utils.logger import logger
@@ -58,10 +60,23 @@ def alpha(self, value: float) -> None:
5860
raise ValueError("value must be greater than 0.")
5961
self._alpha = value
6062

61-
def on_compare_end(self, **kwargs) -> None:
62-
"""On compare end method."""
63-
p_value = kwargs["result"].p_value
64-
if p_value < self.alpha:
63+
def on_compare_end(
64+
self,
65+
result: Any,
66+
X_ref: np.ndarray, # noqa: N803
67+
X_test: np.ndarray,
68+
) -> None:
69+
"""On compare end method.
70+
71+
:param result: result obtained from the `compare` method
72+
:type result: Any
73+
:param X_ref: reference data
74+
:type X_ref: numpy.ndarray
75+
:param X_test: test data
76+
:type X_test: numpy.ndarray
77+
"""
78+
p_value = result.p_value
79+
if p_value <= self.alpha:
6580
logger.info("Drift detected. Resetting detector...")
6681
self.detector.reset() # type: ignore
6782

frouros/callbacks/streaming/history.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""History callback module."""
22

3-
from typing import Any, Dict, List, Optional
3+
from typing import Any, Dict, List, Optional, Union
44

55
from frouros.callbacks.streaming.base import BaseCallbackStreaming
66
from frouros.utils.stats import BaseStat
@@ -62,9 +62,13 @@ def add_additional_vars(self, vars_: List[str]) -> None:
6262
self.additional_vars.extend(vars_)
6363
self.history = {**self.history, **{var: [] for var in self.additional_vars}}
6464

65-
def on_update_end(self, **kwargs) -> None:
66-
"""On update end method."""
67-
self.history["value"].append(kwargs["value"])
65+
def on_update_end(self, value: Union[int, float]) -> None:
66+
"""On update end method.
67+
68+
:param value: value used to update the detector
69+
:type value: Union[int, float]
70+
"""
71+
self.history["value"].append(value)
6872
self.history["num_instances"].append(
6973
self.detector.num_instances # type: ignore
7074
)

0 commit comments

Comments
 (0)