Skip to content

Commit 897946a

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #261 from IFCA/fix-remove-callbacks-kwargs
Fix remove callbacks kwargs
2 parents 028bb22 + e6619b8 commit 897946a

File tree

10 files changed

+118
-42
lines changed

10 files changed

+118
-42
lines changed

frouros/callbacks/base.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import abc
44
from typing import Optional
55

6+
import numpy as np # type: ignore
7+
68

79
class BaseCallback(abc.ABC):
810
"""Abstract class representing a callback."""
@@ -55,11 +57,19 @@ def set_detector(self, detector) -> None:
5557
# )
5658
# self._detector = value
5759

58-
def on_fit_start(self, **kwargs) -> None:
59-
"""On fit start method."""
60+
def on_fit_start(self, X: np.ndarray) -> None: # noqa: N803
61+
"""On fit start method.
62+
63+
:param X: reference data
64+
:type X: numpy.ndarray
65+
"""
6066

61-
def on_fit_end(self, **kwargs) -> None:
62-
"""On fit end method."""
67+
def on_fit_end(self, X: np.ndarray) -> None: # noqa: N803
68+
"""On fit end method.
69+
70+
:param X: reference data
71+
:type X: numpy.ndarray
72+
"""
6373

6474
@abc.abstractmethod
6575
def reset(self) -> None:

frouros/callbacks/batch/base.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,44 @@
11
"""Base callback batch module."""
22

33
import abc
4+
from typing import Any
5+
6+
import numpy as np # type: ignore
47

58
from frouros.callbacks.base import BaseCallback
69

710

811
class BaseCallbackBatch(BaseCallback):
912
"""Callback batch class."""
1013

11-
def on_compare_start(self, **kwargs) -> None:
12-
"""On compare start method."""
13-
14-
def on_compare_end(self, **kwargs) -> None:
15-
"""On compare end method."""
14+
def on_compare_start(
15+
self,
16+
X_ref: np.ndarray, # noqa: N803
17+
X_test: np.ndarray,
18+
) -> None:
19+
"""On compare start method.
20+
21+
:param X_ref: reference data
22+
:type X_ref: numpy.ndarray
23+
:param X_test: test data
24+
:type X_test: numpy.ndarray
25+
"""
26+
27+
def on_compare_end(
28+
self,
29+
result: Any,
30+
X_ref: np.ndarray, # noqa: N803
31+
X_test: np.ndarray,
32+
) -> None:
33+
"""On compare end method.
34+
35+
:param result: result obtained from the `compare` method
36+
:type result: Any
37+
:param X_ref: reference data
38+
:type X_ref: numpy.ndarray
39+
:param X_test: test data
40+
:type X_test: numpy.ndarray
41+
"""
1642

1743
# FIXME: set_detector method as a workaround to # pylint: disable=fixme
1844
# avoid circular import problem. Make it an abstract method and

frouros/callbacks/batch/permutation_test.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,22 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
148148
p_value = (permuted_statistic >= observed_statistic).mean() # type: ignore
149149
return permuted_statistic, p_value
150150

151-
def on_compare_end(self, **kwargs) -> None:
152-
"""On compare end method."""
153-
X_ref, X_test = kwargs["X_ref"], kwargs["X_test"] # noqa: N806
154-
observed_statistic = kwargs["result"][0]
151+
def on_compare_end(
152+
self,
153+
result: Any,
154+
X_ref: np.ndarray, # noqa: N803
155+
X_test: np.ndarray,
156+
) -> None:
157+
"""On compare end method.
158+
159+
:param result: result obtained from the `compare` method
160+
:type result: Any
161+
:param X_ref: reference data
162+
:type X_ref: numpy.ndarray
163+
:param X_test: test data
164+
:type X_test: numpy.ndarray
165+
"""
166+
observed_statistic = result.distance
155167
permuted_statistics, p_value = self._calculate_p_value(
156168
X_ref=X_ref,
157169
X_test=X_test,

frouros/callbacks/batch/reset.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Reset batch callback module."""
22

3-
from typing import Optional
3+
from typing import Any, Optional
4+
5+
import numpy as np # type: ignore
46

57
from frouros.callbacks.batch.base import BaseCallbackBatch
68
from frouros.utils.logger import logger
@@ -58,10 +60,23 @@ def alpha(self, value: float) -> None:
5860
raise ValueError("value must be greater than 0.")
5961
self._alpha = value
6062

61-
def on_compare_end(self, **kwargs) -> None:
62-
"""On compare end method."""
63-
p_value = kwargs["result"].p_value
64-
if p_value < self.alpha:
63+
def on_compare_end(
64+
self,
65+
result: Any,
66+
X_ref: np.ndarray, # noqa: N803
67+
X_test: np.ndarray,
68+
) -> None:
69+
"""On compare end method.
70+
71+
:param result: result obtained from the `compare` method
72+
:type result: Any
73+
:param X_ref: reference data
74+
:type X_ref: numpy.ndarray
75+
:param X_test: test data
76+
:type X_test: numpy.ndarray
77+
"""
78+
p_value = result.p_value
79+
if p_value <= self.alpha:
6580
logger.info("Drift detected. Resetting detector...")
6681
self.detector.reset() # type: ignore
6782

frouros/callbacks/streaming/base.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
"""Base callback streaming module."""
22

33
import abc
4+
from typing import Union
45

56
from frouros.callbacks.base import BaseCallback
67

78

89
class BaseCallbackStreaming(BaseCallback):
910
"""Callback streaming class."""
1011

11-
def on_update_start(self, **kwargs) -> None:
12-
"""On update start method."""
12+
def on_update_start(self, value: Union[int, float]) -> None:
13+
"""On update start method.
1314
14-
def on_update_end(self, **kwargs) -> None:
15-
"""On update end method."""
15+
:param value: value used to update the detector
16+
:type value: Union[int, float]
17+
"""
18+
19+
def on_update_end(self, value: Union[int, float]) -> None:
20+
"""On update end method.
21+
22+
:param value: value used to update the detector
23+
:type value: Union[int, float]
24+
"""
1625

1726
# FIXME: set_detector method as a workaround to # pylint: disable=fixme
1827
# avoid circular import problem. Make it an abstract method and

frouros/callbacks/streaming/history.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""History callback module."""
22

3-
from typing import Any, Dict, List, Optional
3+
from typing import Any, Dict, List, Optional, Union
44

55
from frouros.callbacks.streaming.base import BaseCallbackStreaming
66
from frouros.utils.stats import BaseStat
@@ -62,9 +62,13 @@ def add_additional_vars(self, vars_: List[str]) -> None:
6262
self.additional_vars.extend(vars_)
6363
self.history = {**self.history, **{var: [] for var in self.additional_vars}}
6464

65-
def on_update_end(self, **kwargs) -> None:
66-
"""On update end method."""
67-
self.history["value"].append(kwargs["value"])
65+
def on_update_end(self, value: Union[int, float]) -> None:
66+
"""On update end method.
67+
68+
:param value: value used to update the detector
69+
:type value: Union[int, float]
70+
"""
71+
self.history["value"].append(value)
6872
self.history["num_instances"].append(
6973
self.detector.num_instances # type: ignore
7074
)

frouros/detectors/concept_drift/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,17 +184,17 @@ def update(self, value: Union[int, float], **kwargs) -> Dict[str, Any]:
184184
185185
:param value: value to update detector
186186
:type value: Union[int, float]
187+
:return: callbacks logs
188+
:rtype: Dict[str, Any]]
187189
"""
188190
for callback in self.callbacks: # type: ignore
189191
callback.on_update_start( # type: ignore
190192
value=value,
191-
**kwargs,
192193
)
193194
self._update(value=value, **kwargs)
194195
for callback in self.callbacks: # type: ignore
195196
callback.on_update_end( # type: ignore
196197
value=value,
197-
**kwargs,
198198
)
199199

200200
callbacks_logs = self._get_callbacks_logs()

frouros/detectors/data_drift/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,15 @@ def fit(self, X: np.ndarray, **kwargs) -> Dict[str, Any]: # noqa: N803
187187
for callback in self.callbacks: # type: ignore
188188
callback.on_fit_start(
189189
X=X,
190-
**kwargs,
191190
)
192191
self._fit(X=X, **kwargs)
193192
for callback in self.callbacks: # type: ignore
194193
callback.on_fit_end(
195194
X=X,
196-
**kwargs,
197195
)
198196

199-
logs = self._get_callbacks_logs()
200-
return logs
197+
callbacks_logs = self._get_callbacks_logs()
198+
return callbacks_logs
201199

202200
def reset(self) -> None:
203201
"""Reset method."""

frouros/detectors/data_drift/batch/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def compare(
6060
) -> Tuple[np.ndarray, Dict[str, Any]]:
6161
"""Compare values.
6262
63-
:param X: feature data
63+
:param X: test data
6464
:type X: numpy.ndarray
6565
:return: compare result and callbacks logs
6666
:rtype: Tuple[numpy.ndarray, Dict[str, Any]]

frouros/detectors/data_drift/streaming/base.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,29 @@ def reset(self) -> None:
5656
self._reset()
5757

5858
def update(
59-
self, value: Union[int, float]
59+
self,
60+
value: Union[int, float],
6061
) -> Tuple[Optional[BaseResult], Dict[str, Any]]:
6162
"""Update detector.
6263
6364
:param value: value to use to update the detector
6465
:type value: Union[int, float]
65-
:return: update result
66-
:rtype: Optional[BaseResult]
66+
:return: update result and callbacks logs
67+
:rtype: Tuple[Optional[BaseResult], Dict[str, Any]]
6768
"""
6869
self._common_checks() # noqa: N806
6970
self._specific_checks(X=value) # noqa: N806
7071
self.num_instances += 1
7172

7273
for callback in self.callbacks: # type: ignore
73-
callback.on_update_start(value=value) # type: ignore
74+
callback.on_update_start( # type: ignore
75+
value=value, # type: ignore
76+
)
7477
result = self._update(value=value)
75-
if result is not None:
76-
for callback in self.callbacks: # type: ignore
77-
callback.on_update_end( # type: ignore
78-
value=result.distance, # type: ignore
79-
)
78+
for callback in self.callbacks: # type: ignore
79+
callback.on_update_end( # type: ignore
80+
value=result, # type: ignore
81+
)
8082

8183
callbacks_logs = self._get_callbacks_logs()
8284
return result, callbacks_logs

0 commit comments

Comments
 (0)