Skip to content

Commit 0853dbc

Browse files
Remove kwargs from __init__ method in statistical tests
1 parent 897946a commit 0853dbc

File tree

8 files changed

+97
-60
lines changed

8 files changed

+97
-60
lines changed

frouros/detectors/data_drift/batch/base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ def _specific_checks(self, X: np.ndarray) -> None: # noqa: N803
9797

9898
@abc.abstractmethod
9999
def _apply_method(
100-
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
100+
self,
101+
X_ref: np.ndarray, # noqa: N803
102+
X: np.ndarray,
103+
**kwargs,
101104
) -> Any:
102105
pass
103106

@@ -110,9 +113,13 @@ def _compare(
110113
pass
111114

112115
def _get_result(
113-
self, X: np.ndarray, **kwargs # noqa: N803
116+
self,
117+
X: np.ndarray, # noqa: N803
118+
**kwargs,
114119
) -> Union[List[float], List[Tuple[float, float]], Tuple[float, float]]:
115120
result = self._apply_method( # type: ignore # pylint: disable=not-callable
116-
X_ref=self.X_ref, X=X, **kwargs
121+
X_ref=self.X_ref,
122+
X=X,
123+
**kwargs,
117124
)
118125
return result

frouros/detectors/data_drift/batch/statistical_test/anderson_darling.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,16 @@ class AndersonDarlingTest(BaseStatisticalTest):
1818
1919
:param callbacks: callbacks, defaults to None
2020
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
21-
:param kwargs: additional keyword arguments to pass to scipy.stats.anderson_ksamp
22-
:type kwargs: Dict[str, Any]
2321
2422
:Note:
25-
p-values are bounded between 0.001 and 0.25 according to scipy documentation [1]_.
23+
- Passing additional arguments to `scipy.stats.anderson_ksamp <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.anderson_ksamp.html>`__ can be done using :func:`compare` kwargs.
24+
- p-values are bounded between 0.001 and 0.25 according to `scipy.stats.anderson_ksamp <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.anderson_ksamp.html>`__.
2625
2726
:References:
2827
2928
.. [scholz1987k] Scholz, Fritz W., and Michael A. Stephens.
3029
"K-sample Anderson–Darling tests."
3130
Journal of the American Statistical Association 82.399 (1987): 918-924.
32-
[1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.anderson_ksamp.html # noqa: E501 # pylint: disable=line-too-long
3331
3432
:Example:
3533
@@ -42,29 +40,30 @@ class AndersonDarlingTest(BaseStatisticalTest):
4240
>>> _ = detector.fit(X=X)
4341
>>> detector.compare(X=Y)[0]
4442
StatisticalResult(statistic=32.40316586267425, p_value=0.001)
45-
"""
43+
""" # noqa: E501 # pylint: disable=line-too-long
4644

4745
def __init__(  # noqa: D107
    self,
    callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
) -> None:
    # The detector is registered as working on numerical, univariate data.
    numerical = NumericalData()
    univariate = UnivariateData()
    super().__init__(
        data_type=numerical,
        statistical_type=univariate,
        callbacks=callbacks,
    )
5854

55+
@staticmethod
5956
def _statistical_test(
60-
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
57+
X_ref: np.ndarray, # noqa: N803
58+
X: np.ndarray,
59+
**kwargs,
6160
) -> StatisticalResult:
6261
test = anderson_ksamp(
6362
samples=[
6463
X_ref,
6564
X,
6665
],
67-
**self.kwargs,
66+
**kwargs,
6867
)
6968
test = StatisticalResult(
7069
statistic=test.statistic,

frouros/detectors/data_drift/batch/statistical_test/base.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,16 @@ class BaseStatisticalTest(BaseDataDriftBatch):
1515
"""Abstract class representing a statistical test."""
1616

1717
def _apply_method(
18-
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
18+
self,
19+
X_ref: np.ndarray, # noqa: N803
20+
X: np.ndarray,
21+
**kwargs,
1922
) -> Tuple[float, float]:
20-
statistical_test = self._statistical_test(X_ref=X_ref, X=X, **kwargs)
23+
statistical_test = self._statistical_test(
24+
X_ref=X_ref,
25+
X=X,
26+
**kwargs,
27+
)
2128
return statistical_test
2229

2330
def _compare(
@@ -30,8 +37,11 @@ def _compare(
3037
result = self._get_result(X=X, **kwargs)
3138
return result # type: ignore
3239

40+
@staticmethod
@abc.abstractmethod
def _statistical_test(
    X_ref: np.ndarray,  # noqa: N803
    X: np.ndarray,
    **kwargs,
) -> StatisticalResult:
    """Compute the statistical test between two samples.

    Abstract static hook: each concrete test must override it.

    :param X_ref: first sample
    :type X_ref: numpy.ndarray
    :param X: second sample
    :type X: numpy.ndarray
    :param kwargs: additional keyword arguments for the concrete test
    :return: result of the test
    :rtype: StatisticalResult
    """

frouros/detectors/data_drift/batch/statistical_test/chisquare.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ class ChiSquareTest(BaseStatisticalTest):
1919
2020
:param callbacks: callbacks, defaults to None
2121
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
22-
:param kwargs: additional keyword arguments to pass to scipy.stats.chi2_contingency
23-
:type kwargs: Dict[str, Any]
22+
23+
:Note:
24+
- Passing additional arguments to `scipy.stats.chi2_contingency <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chi2_contingency.html>`__ can be done using :func:`compare` kwargs.
2425
2526
:References:
2627
@@ -42,34 +43,43 @@ class ChiSquareTest(BaseStatisticalTest):
4243
>>> _ = detector.fit(X=X)
4344
>>> detector.compare(X=Y)[0]
4445
StatisticalResult(statistic=9.81474665685192, p_value=0.0017311812135839511)
45-
"""
46+
""" # noqa: E501 # pylint: disable=line-too-long
4647

4748
def __init__(  # noqa: D107
    self,
    callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
) -> None:
    # The detector is registered as working on categorical, univariate data.
    categorical = CategoricalData()
    univariate = UnivariateData()
    super().__init__(
        data_type=categorical,
        statistical_type=univariate,
        callbacks=callbacks,
    )
5857

58+
@staticmethod
def _statistical_test(
    X_ref: np.ndarray,  # noqa: N803
    X: np.ndarray,
    **kwargs,
) -> StatisticalResult:
    """Run ``chi2_contingency`` on the frequencies of both samples.

    :param X_ref: sample whose frequencies form the second table row
    :type X_ref: numpy.ndarray
    :param X: sample whose frequencies form the first table row
    :type X: numpy.ndarray
    :param kwargs: keyword arguments forwarded to ``chi2_contingency``
    :return: statistic and p-value of the test
    :rtype: StatisticalResult
    """
    expected_freqs, observed_freqs = ChiSquareTest._calculate_frequencies(
        X_ref=X_ref,
        X=X,
    )
    # Row order matters to callers inspecting the table: observed first.
    contingency_table = np.array([observed_freqs, expected_freqs])
    statistic, p_value, _dof, _expected = chi2_contingency(
        observed=contingency_table,
        **kwargs,
    )
    return StatisticalResult(statistic=statistic, p_value=p_value)
6978

7079
@staticmethod
7180
def _calculate_frequencies(
72-
X_ref: np.ndarray, X: np.ndarray # noqa: N803
81+
X_ref: np.ndarray, # noqa: N803
82+
X: np.ndarray,
7383
) -> Tuple[List[int], List[int]]:
7484
X_ref_counter, X_counter = [ # noqa: N806
7585
*map(collections.Counter, [X_ref, X]) # noqa: N806

frouros/detectors/data_drift/batch/statistical_test/cvm.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ class CVMTest(BaseStatisticalTest):
1919
2020
:param callbacks: callbacks, defaults to None
2121
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
22-
:param kwargs: additional keyword arguments to pass to scipy.stats.cramervonmises_2samp
23-
:type kwargs: Dict[str, Any]
22+
23+
:Note:
24+
- Passing additional arguments to `scipy.stats.cramervonmises_2samp <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.cramervonmises_2samp.html>`__ can be done using :func:`compare` kwargs.
2425
2526
:References:
2627
@@ -39,19 +40,17 @@ class CVMTest(BaseStatisticalTest):
3940
>>> _ = detector.fit(X=X)
4041
>>> detector.compare(X=Y)[0]
4142
StatisticalResult(statistic=5.331699999999998, p_value=1.7705426014202885e-10)
42-
""" # noqa: E501
43+
""" # noqa: E501 # pylint: disable=line-too-long
4344

4445
def __init__(  # noqa: D107
    self,
    callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
) -> None:
    # The detector is registered as working on numerical, univariate data.
    numerical = NumericalData()
    univariate = UnivariateData()
    super().__init__(
        data_type=numerical,
        statistical_type=univariate,
        callbacks=callbacks,
    )
5554

5655
@BaseStatisticalTest.X_ref.setter # type: ignore[attr-defined]
5756
def X_ref(self, value: Optional[np.ndarray]) -> None: # noqa: N802
@@ -75,13 +74,19 @@ def _check_sufficient_samples(X: np.ndarray) -> None: # noqa: N803
7574
if X.shape[0] < 2:
7675
raise InsufficientSamplesError("Number of samples must be at least 2.")
7776

77+
@staticmethod
def _statistical_test(
    X_ref: np.ndarray,  # noqa: N803
    X: np.ndarray,
    **kwargs,
) -> StatisticalResult:
    """Run ``cramervonmises_2samp`` on the two samples.

    :param X_ref: sample passed to scipy as ``x``
    :type X_ref: numpy.ndarray
    :param X: sample passed to scipy as ``y``
    :type X: numpy.ndarray
    :param kwargs: keyword arguments forwarded to ``cramervonmises_2samp``
    :return: statistic and p-value of the test
    :rtype: StatisticalResult
    """
    cvm_result = cramervonmises_2samp(x=X_ref, y=X, **kwargs)
    return StatisticalResult(
        statistic=cvm_result.statistic,
        p_value=cvm_result.pvalue,
    )

frouros/detectors/data_drift/batch/statistical_test/ks.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ class KSTest(BaseStatisticalTest):
1818
1919
:param callbacks: callbacks, defaults to None
2020
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
21-
:param kwargs: additional keyword arguments to pass to scipy.stats.ks_2samp
22-
:type kwargs: Dict[str, Any]
21+
22+
:Note:
23+
- Passing additional arguments to `scipy.stats.ks_2samp <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ks_2samp.html>`__ can be done using :func:`compare` kwargs.
2324
2425
:References:
2526
@@ -38,28 +39,32 @@ class KSTest(BaseStatisticalTest):
3839
>>> _ = detector.fit(X=X)
3940
>>> detector.compare(X=Y)[0]
4041
StatisticalResult(statistic=0.55, p_value=3.0406585087050305e-14)
41-
"""
42+
""" # noqa: E501 # pylint: disable=line-too-long
4243

4344
def __init__(  # noqa: D107
    self,
    callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
) -> None:
    # The detector is registered as working on numerical, univariate data.
    numerical = NumericalData()
    univariate = UnivariateData()
    super().__init__(
        data_type=numerical,
        statistical_type=univariate,
        callbacks=callbacks,
    )
5453

54+
@staticmethod
def _statistical_test(
    X_ref: np.ndarray,  # noqa: N803
    X: np.ndarray,
    **kwargs,
) -> StatisticalResult:
    """Run ``ks_2samp`` on the two samples.

    :param X_ref: sample passed to scipy as ``data1``
    :type X_ref: numpy.ndarray
    :param X: sample passed to scipy as ``data2``
    :type X: numpy.ndarray
    :param kwargs: keyword arguments forwarded to ``ks_2samp``;
        ``alternative`` defaults to ``"two-sided"`` and ``method`` to
        ``"auto"`` when not supplied
    :return: statistic and p-value of the test
    :rtype: StatisticalResult
    :raises TypeError: if ``kwargs`` holds a key ``ks_2samp`` does not accept
    """
    # Defaults go first so caller-supplied values win. Every remaining key
    # is forwarded rather than silently dropped, so a misspelled option
    # fails loudly instead of being ignored.
    ks_kwargs = {"alternative": "two-sided", "method": "auto", **kwargs}
    test = ks_2samp(data1=X_ref, data2=X, **ks_kwargs)
    return StatisticalResult(
        statistic=test.statistic,
        p_value=test.pvalue,
    )

frouros/detectors/data_drift/batch/statistical_test/mann_whitney_u.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ class MannWhitneyUTest(BaseStatisticalTest):
1818
1919
:param callbacks: callbacks, defaults to None
2020
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
21-
:param kwargs: additional keyword arguments to pass to scipy.stats.mannwhitneyu
22-
:type kwargs: Dict[str, Any]
21+
22+
:Note:
23+
- Passing additional arguments to `scipy.stats.mannwhitneyu <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.mannwhitneyu.html>`__ can be done using :func:`compare` kwargs.
2324
2425
:References:
2526
@@ -39,29 +40,30 @@ class MannWhitneyUTest(BaseStatisticalTest):
3940
>>> _ = detector.fit(X=X)
4041
>>> detector.compare(X=Y)[0]
4142
StatisticalResult(statistic=2139.0, p_value=2.7623373527697943e-12)
42-
"""
43+
""" # noqa: E501 # pylint: disable=line-too-long
4344

4445
def __init__(  # noqa: D107
    self,
    callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
) -> None:
    # The detector is registered as working on numerical, univariate data.
    numerical = NumericalData()
    univariate = UnivariateData()
    super().__init__(
        data_type=numerical,
        statistical_type=univariate,
        callbacks=callbacks,
    )
5554

55+
@staticmethod
5656
def _statistical_test(
57-
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
57+
X_ref: np.ndarray, # noqa: N803
58+
X: np.ndarray,
59+
**kwargs,
5860
) -> StatisticalResult:
5961
# Merge the defaults and the caller's kwargs into ONE mapping (caller wins).
# Passing ``alternative=kwargs.get(...)`` alongside ``**kwargs`` would hand
# the same keyword to the call twice whenever the caller sets it, raising
# ``TypeError: got multiple values for keyword argument``.
test = mannwhitneyu(  # pylint: disable=unexpected-keyword-arg
    x=X_ref,
    y=X,
    **{"alternative": "two-sided", "nan_policy": "raise", **kwargs},
)
6668
test = StatisticalResult(
6769
statistic=test.statistic,

frouros/detectors/data_drift/batch/statistical_test/welch_t_test.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ class WelchTTest(BaseStatisticalTest):
1818
1919
:param callbacks: callbacks, defaults to None
2020
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
21-
:param kwargs: additional keyword arguments to pass to scipy.stats.ttest_ind
22-
:type kwargs: Dict[str, Any]
21+
22+
:Note:
23+
- Passing additional arguments to `scipy.stats.ttest_ind <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_ind.html>`__ can be done using :func:`compare` kwargs.
2324
2425
:References:
2526
@@ -39,22 +40,20 @@ class WelchTTest(BaseStatisticalTest):
3940
>>> _ = detector.fit(X=X)
4041
>>> detector.compare(X=Y)[0]
4142
StatisticalResult(statistic=-7.651304662806378, p_value=8.685225410826823e-13)
42-
"""
43+
""" # noqa: E501 # pylint: disable=line-too-long
4344

4445
def __init__(  # noqa: D107
    self,
    callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
) -> None:
    # The detector is registered as working on numerical, univariate data.
    numerical = NumericalData()
    univariate = UnivariateData()
    super().__init__(
        data_type=numerical,
        statistical_type=univariate,
        callbacks=callbacks,
    )
5554

55+
@staticmethod
5656
def _statistical_test(
57-
self,
5857
X_ref: np.ndarray, # noqa: N803
5958
X: np.ndarray,
6059
**kwargs,
@@ -63,8 +62,8 @@ def _statistical_test(
6362
a=X_ref,
6463
b=X,
6564
equal_var=False,
66-
alternative="two-sided",
67-
**self.kwargs,
65+
alternative=kwargs.get("alternative", "two-sided"),
66+
**kwargs,
6867
)
6968
test = StatisticalResult(
7069
statistic=test.statistic,

0 commit comments

Comments
 (0)