Skip to content

Commit f507e8b

Browse files
Add missing GeometricMovingAverage __init__ method
1 parent 8b14bc1 commit f507e8b

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

frouros/detectors/concept_drift/streaming/change_detection/geometric_moving_average.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Geometric Moving Average module."""
22

3+
from typing import Optional, List, Union
4+
5+
from frouros.callbacks.streaming.base import BaseCallbackStreaming
36
from frouros.detectors.concept_drift.streaming.change_detection.base import (
47
BaseCUSUM,
58
BaseCUSUMConfig,
@@ -40,8 +43,8 @@ def __init__( # noqa: D107
4043
class GeometricMovingAverage(BaseCUSUM):
4144
"""Geometric Moving Average [robertst1959control]_ detector.
4245
43-
:param config: configuration object of the detector
44-
:type config: GeometricMovingAverageConfig
46+
:param config: configuration object of the detector, defaults to None. If None, the default configuration of :class:`GeometricMovingAverageConfig` is used.
47+
:type config: Optional[GeometricMovingAverageConfig]
4548
:param callbacks: callbacks, defaults to None
4649
:type callbacks: Optional[Union[BaseCallbackStreaming, List[BaseCallbackStreaming]]]
4750
@@ -60,17 +63,29 @@ class GeometricMovingAverage(BaseCUSUM):
6063
>>> dist_a = np.random.normal(loc=0.2, scale=0.01, size=1000)
6164
>>> dist_b = np.random.normal(loc=0.8, scale=0.04, size=1000)
6265
>>> stream = np.concatenate((dist_a, dist_b))
63-
>>> detector = GeometricMovingAverage(config=GeometricMovingAverageConfig(lambda_=0.1))
66+
>>> detector = GeometricMovingAverage(config=GeometricMovingAverageConfig(lambda_=0.3))
6467
>>> for i, value in enumerate(stream):
6568
... _ = detector.update(value=value)
6669
... if detector.drift:
6770
... print(f"Change detected at index {i}")
6871
... break
69-
Change detected at index 1018
72+
Change detected at index 1071
7073
""" # noqa: E501
7174

7275
config_type = GeometricMovingAverageConfig # type: ignore
7376

77+
def __init__( # noqa: D107
78+
self,
79+
config: Optional[GeometricMovingAverageConfig] = None,
80+
callbacks: Optional[
81+
Union[BaseCallbackStreaming, List[BaseCallbackStreaming]]
82+
] = None,
83+
) -> None:
84+
super().__init__(
85+
config=config,
86+
callbacks=callbacks,
87+
)
88+
7489
def _update_sum(self, error_rate: float) -> None:
7590
self.sum_ = self.config.alpha * self.sum_ + ( # type: ignore
7691
1 - self.config.alpha # type: ignore

0 commit comments

Comments
 (0)