Skip to content

Commit 3d6157e

Browse files
committed
refactor(pyargus): define a runtime checkable Signal protocol
1 parent 8027f86 commit 3d6157e

File tree

6 files changed

+70
-17
lines changed

6 files changed

+70
-17
lines changed

noxfile.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,16 @@ def ruff(session: nox.Session):
8989
def mypy(session: nox.Session):
9090
session.conda_install("mypy", "typing-extensions", "pytest", "hypothesis", "numpy")
9191
session.env.update(ENV)
92+
9293
with session.chdir(CURRENT_DIR / "pyargus"):
9394
session.install("-e", ".")
9495
session.run("mypy", ".")
95-
session.run("stubtest", "argus")
96+
session.run(
97+
"stubtest",
98+
"argus",
99+
"--allowlist",
100+
str(CURRENT_DIR / "pyargus/stubtest_allow.txt"),
101+
)
96102

97103

98104
@nox.session

pyargus/argus/_argus.pyi

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import ClassVar, Generic, Protocol, TypeVar, final
1+
from typing import ClassVar, Protocol, final
22

33
from typing_extensions import Self
44

@@ -134,9 +134,7 @@ class dtype: # noqa: N801
134134
def __eq__(self, other: object) -> bool: ...
135135
def __int__(self) -> int: ...
136136

137-
_SignalKind = TypeVar("_SignalKind", bool, int, float, covariant=True)
138-
139-
class Signal(Generic[_SignalKind], Protocol):
137+
class Signal:
140138
def is_empty(self) -> bool: ...
141139
@property
142140
def start_time(self) -> float | None: ...
@@ -146,16 +144,16 @@ class Signal(Generic[_SignalKind], Protocol):
146144
def kind(self) -> dtype: ...
147145

148146
@final
149-
class BoolSignal(Signal[bool]):
147+
class BoolSignal(Signal):
150148
@classmethod
151149
def constant(cls, value: bool) -> Self: ...
152150
@classmethod
153151
def from_samples(cls, samples: list[tuple[float, bool]]) -> Self: ...
154152
def push(self, time: float, value: bool) -> None: ...
155-
def at(self, time: float) -> _SignalKind | None: ...
153+
def at(self, time: float) -> bool | None: ...
156154

157155
@final
158-
class IntSignal(Signal[int]):
156+
class IntSignal(Signal):
159157
@classmethod
160158
def constant(cls, value: int) -> Self: ...
161159
@classmethod
@@ -164,7 +162,7 @@ class IntSignal(Signal[int]):
164162
def at(self, time: float) -> int | None: ...
165163

166164
@final
167-
class UnsignedIntSignal(Signal[int]):
165+
class UnsignedIntSignal(Signal):
168166
@classmethod
169167
def constant(cls, value: int) -> Self: ...
170168
@classmethod
@@ -173,7 +171,7 @@ class UnsignedIntSignal(Signal[int]):
173171
def at(self, time: float) -> int | None: ...
174172

175173
@final
176-
class FloatSignal(Signal[float]):
174+
class FloatSignal(Signal):
177175
@classmethod
178176
def constant(cls, value: float) -> Self: ...
179177
@classmethod

pyargus/argus/signals.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,43 @@
1-
from argus._argus import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
1+
from typing import List, Optional, Protocol, Tuple, TypeVar, runtime_checkable
2+
3+
from typing_extensions import Self
4+
5+
from argus._argus import BoolSignal, FloatSignal, IntSignal, UnsignedIntSignal, dtype
6+
7+
T = TypeVar("T", bool, int, float)
8+
9+
10+
@runtime_checkable
11+
class Signal(Protocol[T]):
12+
def is_empty(self) -> bool:
13+
...
14+
15+
@property
16+
def start_time(self) -> Optional[float]:
17+
...
18+
19+
@property
20+
def end_time(self) -> Optional[float]:
21+
...
22+
23+
@property
24+
def kind(self) -> dtype:
25+
...
26+
27+
@classmethod
28+
def constant(cls, value: T) -> Self:
29+
...
30+
31+
@classmethod
32+
def from_samples(cls, samples: List[Tuple[float, T]]) -> Self:
33+
...
34+
35+
def push(self, time: float, value: T) -> None:
36+
...
37+
38+
def at(self, time: float) -> Optional[T]:
39+
...
40+
241

342
__all__ = [
443
"Signal",

pyargus/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ addopts = ["--import-mode=importlib"]
4141
testpaths = ["tests"]
4242

4343
[tool.mypy]
44+
packages = ["argus"]
4445
# ignore_missing_imports = true
4546
show_error_codes = true
4647
plugins = ["numpy.typing.mypy_plugin"]

pyargus/stubtest_allow.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
argus.signals.Protocol
2+
argus.signals.TypeVar.__bound__
3+
argus.signals.TypeVar.__constraints__
4+
argus.signals.TypeVar.__contravariant__
5+
argus.signals.TypeVar.__covariant__

pyargus/tests/test_signals.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], dtype]]:
9595
def test_correct_constant_signals(data: st.DataObject) -> None:
9696
dtype_ = data.draw(gen_dtype())
9797
signal = data.draw(constant_signal(dtype_))
98+
assert isinstance(signal, argus.Signal)
9899

99100
assert not signal.is_empty()
100101
assert signal.start_time is None
@@ -108,6 +109,7 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
108109

109110
note(f"Samples: {gen_samples}")
110111
signal = argus.signal(dtype_, data=xs)
112+
assert isinstance(signal, argus.Signal)
111113
if len(xs) > 0:
112114
expected_start_time = xs[0][0]
113115
expected_end_time = xs[-1][0]
@@ -165,18 +167,20 @@ def test_signal_create_should_fail(data: st.DataObject) -> None:
165167
@given(st.data())
166168
def test_push_to_empty_signal(data: st.DataObject) -> None:
167169
dtype_ = data.draw(gen_dtype())
168-
sig = data.draw(empty_signal(dtype_=dtype_))
169-
assert sig.is_empty()
170+
signal = data.draw(empty_signal(dtype_=dtype_))
171+
assert isinstance(signal, argus.Signal)
172+
assert signal.is_empty()
170173
element = data.draw(gen_element_fn(dtype_))
171174
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
172-
sig.push(0.0, element) # type: ignore[attr-defined]
175+
signal.push(0.0, element) # type: ignore[attr-defined]
173176

174177

175178
@given(st.data())
176179
def test_push_to_constant_signal(data: st.DataObject) -> None:
177180
dtype_ = data.draw(gen_dtype())
178-
sig = data.draw(constant_signal(dtype_=dtype_))
179-
assert not sig.is_empty()
181+
signal = data.draw(constant_signal(dtype_=dtype_))
182+
assert isinstance(signal, argus.Signal)
183+
assert not signal.is_empty()
180184
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype_=dtype_))[0]
181185
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
182-
sig.push(*sample) # type: ignore[attr-defined]
186+
signal.push(*sample) # type: ignore[attr-defined]

0 commit comments

Comments
 (0)