Skip to content

Commit 933a17d

Browse files
Remove kwargs to be passed to callbacks. Set fixed variables instead
1 parent 028bb22 commit 933a17d

File tree

7 files changed

+74
-29
lines changed

7 files changed

+74
-29
lines changed

frouros/callbacks/base.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import abc
44
from typing import Optional
55

6+
import numpy as np # type: ignore
7+
68

79
class BaseCallback(abc.ABC):
810
"""Abstract class representing a callback."""
@@ -55,11 +57,19 @@ def set_detector(self, detector) -> None:
5557
# )
5658
# self._detector = value
5759

58-
def on_fit_start(self, **kwargs) -> None:
59-
"""On fit start method."""
60+
def on_fit_start(self, X: np.ndarray) -> None: # noqa: N803
61+
"""On fit start method.
62+
63+
:param X: reference data
64+
:type X: numpy.ndarray
65+
"""
6066

61-
def on_fit_end(self, **kwargs) -> None:
62-
"""On fit end method."""
67+
def on_fit_end(self, X: np.ndarray) -> None: # noqa: N803
68+
"""On fit end method.
69+
70+
:param X: reference data
71+
:type X: numpy.ndarray
72+
"""
6373

6474
@abc.abstractmethod
6575
def reset(self) -> None:

frouros/callbacks/batch/base.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,44 @@
11
"""Base callback batch module."""
22

33
import abc
4+
from typing import Any
5+
6+
import numpy as np # type: ignore
47

58
from frouros.callbacks.base import BaseCallback
69

710

811
class BaseCallbackBatch(BaseCallback):
912
"""Callback batch class."""
1013

11-
def on_compare_start(self, **kwargs) -> None:
12-
"""On compare start method."""
13-
14-
def on_compare_end(self, **kwargs) -> None:
15-
"""On compare end method."""
14+
def on_compare_start(
15+
self,
16+
X_ref: np.ndarray, # noqa: N803
17+
X_test: np.ndarray,
18+
) -> None:
19+
"""On compare start method.
20+
21+
:param X_ref: reference data
22+
:type X_ref: numpy.ndarray
23+
:param X_test: test data
24+
:type X_test: numpy.ndarray
25+
"""
26+
27+
def on_compare_end(
28+
self,
29+
result: Any,
30+
X_ref: np.ndarray, # noqa: N803
31+
X_test: np.ndarray,
32+
) -> None:
33+
"""On compare end method.
34+
35+
:param result: result obtained from the `compare` method
36+
:type result: Any
37+
:param X_ref: reference data
38+
:type X_ref: numpy.ndarray
39+
:param X_test: test data
40+
:type X_test: numpy.ndarray
41+
"""
1642

1743
# FIXME: set_detector method as a workaround to # pylint: disable=fixme
1844
# avoid circular import problem. Make it an abstract method and

frouros/callbacks/streaming/base.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
"""Base callback streaming module."""
22

33
import abc
4+
from typing import Union
45

56
from frouros.callbacks.base import BaseCallback
67

78

89
class BaseCallbackStreaming(BaseCallback):
910
"""Callback streaming class."""
1011

11-
def on_update_start(self, **kwargs) -> None:
12-
"""On update start method."""
12+
def on_update_start(self, value: Union[int, float]) -> None:
13+
"""On update start method.
1314
14-
def on_update_end(self, **kwargs) -> None:
15-
"""On update end method."""
15+
:param value: value used to update the detector
16+
:type value: Union[int, float]
17+
"""
18+
19+
def on_update_end(self, value: Union[int, float]) -> None:
20+
"""On update end method.
21+
22+
:param value: value used to update the detector
23+
:type value: Union[int, float]
24+
"""
1625

1726
# FIXME: set_detector method as a workaround to # pylint: disable=fixme
1827
# avoid circular import problem. Make it an abstract method and

frouros/detectors/concept_drift/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,17 +184,17 @@ def update(self, value: Union[int, float], **kwargs) -> Dict[str, Any]:
184184
185185
:param value: value to update detector
186186
:type value: Union[int, float]
187+
:return: callbacks logs
188+
:rtype: Dict[str, Any]]
187189
"""
188190
for callback in self.callbacks: # type: ignore
189191
callback.on_update_start( # type: ignore
190192
value=value,
191-
**kwargs,
192193
)
193194
self._update(value=value, **kwargs)
194195
for callback in self.callbacks: # type: ignore
195196
callback.on_update_end( # type: ignore
196197
value=value,
197-
**kwargs,
198198
)
199199

200200
callbacks_logs = self._get_callbacks_logs()

frouros/detectors/data_drift/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,15 @@ def fit(self, X: np.ndarray, **kwargs) -> Dict[str, Any]: # noqa: N803
187187
for callback in self.callbacks: # type: ignore
188188
callback.on_fit_start(
189189
X=X,
190-
**kwargs,
191190
)
192191
self._fit(X=X, **kwargs)
193192
for callback in self.callbacks: # type: ignore
194193
callback.on_fit_end(
195194
X=X,
196-
**kwargs,
197195
)
198196

199-
logs = self._get_callbacks_logs()
200-
return logs
197+
callbacks_logs = self._get_callbacks_logs()
198+
return callbacks_logs
201199

202200
def reset(self) -> None:
203201
"""Reset method."""

frouros/detectors/data_drift/batch/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def compare(
6060
) -> Tuple[np.ndarray, Dict[str, Any]]:
6161
"""Compare values.
6262
63-
:param X: feature data
63+
:param X: test data
6464
:type X: numpy.ndarray
6565
:return: compare result and callbacks logs
6666
:rtype: Tuple[numpy.ndarray, Dict[str, Any]]

frouros/detectors/data_drift/streaming/base.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,29 @@ def reset(self) -> None:
5656
self._reset()
5757

5858
def update(
59-
self, value: Union[int, float]
59+
self,
60+
value: Union[int, float],
6061
) -> Tuple[Optional[BaseResult], Dict[str, Any]]:
6162
"""Update detector.
6263
6364
:param value: value to use to update the detector
6465
:type value: Union[int, float]
65-
:return: update result
66-
:rtype: Optional[BaseResult]
66+
:return: update result and callbacks logs
67+
:rtype: Tuple[Optional[BaseResult], Dict[str, Any]]
6768
"""
6869
self._common_checks() # noqa: N806
6970
self._specific_checks(X=value) # noqa: N806
7071
self.num_instances += 1
7172

7273
for callback in self.callbacks: # type: ignore
73-
callback.on_update_start(value=value) # type: ignore
74+
callback.on_update_start( # type: ignore
75+
value=value, # type: ignore
76+
)
7477
result = self._update(value=value)
75-
if result is not None:
76-
for callback in self.callbacks: # type: ignore
77-
callback.on_update_end( # type: ignore
78-
value=result.distance, # type: ignore
79-
)
78+
for callback in self.callbacks: # type: ignore
79+
callback.on_update_end( # type: ignore
80+
value=result, # type: ignore
81+
)
8082

8183
callbacks_logs = self._get_callbacks_logs()
8284
return result, callbacks_logs

0 commit comments

Comments
 (0)