Skip to content

Commit 677efb0

Browse files
authored
Type hints for the module conf (#1719)
* Remove MyPy exceptions for the conf module * Make Jackknife generic over the wrapped type It' a wrapper after all, so it wraps a model. * Add types to all the methods * Describe the overload of predict_one * Add trivial return types to RollingQuantile
1 parent b6087fc commit 677efb0

File tree

4 files changed

+39
-13
lines changed

4 files changed

+39
-13
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ module = [
210210
"river.imblearn.*",
211211
"river.feature_selection.*",
212212
"river.active.*",
213-
"river.conf.*",
214213
"river.neural_net.*",
215214
"river.test_estimators",
216215
"river.dummy",

river/conf/interval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ class Interval:
2424
upper: float
2525

2626
@property
27-
def center(self):
27+
def center(self) -> float:
2828
"""The center of the interval."""
2929
return (self.lower + self.upper) / 2
3030

3131
@property
32-
def width(self):
32+
def width(self) -> float:
3333
"""The width of the interval."""
3434
return self.upper - self.lower
3535

36-
def __contains__(self, x):
36+
def __contains__(self, x: float) -> bool:
3737
return self.lower <= x <= self.upper

river/conf/jackknife.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from __future__ import annotations
22

3-
from river import base, stats
3+
from collections.abc import Iterator
4+
from typing import Literal, TypeVar, overload
5+
6+
from river import base, compose, stats
47

58
from . import interval
69

10+
T = TypeVar("T", bound=base.Regressor)
11+
712

8-
class RegressionJackknife(base.Wrapper, base.Regressor):
13+
class RegressionJackknife(base.Wrapper[T], base.Regressor):
914
"""Jackknife method for regression.
1015
1116
This is a conformal prediction method for regression. It is based on the jackknife method. The
@@ -81,7 +86,7 @@ class RegressionJackknife(base.Wrapper, base.Regressor):
8186

8287
def __init__(
8388
self,
84-
regressor: base.Regressor,
89+
regressor: T,
8590
confidence_level: float = 0.95,
8691
window_size: int | None = None,
8792
):
@@ -100,24 +105,46 @@ def __init__(
100105
)
101106

102107
@property
103-
def _wrapped_model(self):
108+
def _wrapped_model(self) -> T:
104109
return self.regressor
105110

106111
@classmethod
107-
def _unit_test_params(cls):
112+
def _unit_test_params(cls) -> Iterator[dict[str, compose.Pipeline]]:
108113
from river import linear_model, preprocessing
109114

110115
yield {"regressor": (preprocessing.StandardScaler() | linear_model.LinearRegression())}
111116

112-
def learn_one(self, x, y, **kwargs):
117+
def learn_one(
118+
self, x: dict[base.typing.FeatureName, object], y: base.typing.RegTarget, **kwargs: object
119+
) -> None:
113120
# Update the quantiles
114121
error = y - self.regressor.predict_one(x)
115122
self._lower.update(error)
116123
self._upper.update(error)
117124

118125
self.regressor.learn_one(x, y, **kwargs)
119126

120-
def predict_one(self, x, with_interval=False, **kwargs):
127+
@overload
128+
def predict_one(
129+
self,
130+
x: dict[base.typing.FeatureName, object],
131+
with_interval: Literal[False] = False,
132+
**kwargs: object,
133+
) -> float: ...
134+
@overload
135+
def predict_one(
136+
self,
137+
x: dict[base.typing.FeatureName, object],
138+
with_interval: Literal[True],
139+
**kwargs: object,
140+
) -> interval.Interval: ...
141+
142+
def predict_one(
143+
self,
144+
x: dict[base.typing.FeatureName, object],
145+
with_interval: bool = False,
146+
**kwargs: object,
147+
) -> float | interval.Interval:
121148
"""Predict the output of features `x`.
122149
123150
Parameters

river/stats/quantile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,12 @@ def __init__(self, q: float, window_size: int):
136136
self.window_size_value = window_size
137137
self._is_updated = False
138138

139-
def update(self, x):
139+
def update(self, x) -> None:
140140
self._rolling_quantile.update(x)
141141
if not self._is_updated:
142142
self._is_updated = True
143143

144-
def get(self):
144+
def get(self) -> float | None:
145145
if not self._is_updated:
146146
return None
147147
return self._rolling_quantile.get()

0 commit comments

Comments
 (0)