Skip to content

Commit e2cfe3d

Browse files
committed
refactor(pyargus): data type name
1 parent 8093ab7 commit e2cfe3d

File tree

5 files changed

+86
-82
lines changed

5 files changed

+86
-82
lines changed

pyargus/argus/__init__.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import List, Optional, Tuple, Type, Union
44

55
from argus import _argus
6-
from argus._argus import DType as DType
6+
from argus._argus import dtype
77
from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt
88
from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
99

@@ -15,25 +15,19 @@
1515
AllowedDtype = Union[bool, int, float]
1616

1717

18-
def declare_var(name: str, dtype: Union[DType, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]:
18+
def declare_var(name: str, dtype_: Union[dtype, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]:
1919
"""Declare a variable with the given name and type"""
20-
if isinstance(dtype, type):
21-
if dtype == bool:
22-
dtype = DType.Bool
23-
elif dtype == int:
24-
dtype = DType.Int
25-
elif dtype == float:
26-
dtype = DType.Float
27-
28-
if dtype == DType.Bool:
20+
dtype_ = dtype.convert(dtype_)
21+
22+
if dtype_ == dtype.bool_:
2923
return VarBool(name)
30-
elif dtype == DType.Int:
24+
elif dtype_ == dtype.int64:
3125
return VarInt(name)
32-
elif dtype == DType.UnsignedInt:
26+
elif dtype_ == dtype.uint64:
3327
return VarUInt(name)
34-
elif dtype == DType.Float:
28+
elif dtype_ == dtype.float64:
3529
return VarFloat(name)
36-
raise TypeError(f"unsupported variable type `{dtype}`")
30+
raise TypeError(f"unsupported variable type `{dtype_}`")
3731

3832

3933
def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstFloat]:
@@ -48,7 +42,7 @@ def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstF
4842

4943

5044
def signal(
51-
dtype: Union[DType, Type[AllowedDtype]],
45+
dtype_: Union[dtype, Type[AllowedDtype]],
5246
*,
5347
data: Optional[Union[AllowedDtype, List[Tuple[float, AllowedDtype]]]] = None,
5448
) -> Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]:
@@ -57,7 +51,7 @@ def signal(
5751
Parameters
5852
----------
5953
60-
dtype:
54+
dtype_:
6155
Type of the signal
6256
6357
data :
@@ -67,21 +61,21 @@ def signal(
6761
factory: Type[Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]]
6862
expected_type: Type[AllowedDtype]
6963

70-
dtype = DType.convert(dtype)
71-
if dtype == DType.Bool:
64+
dtype_ = dtype.convert(dtype_)
65+
if dtype_ == dtype.bool_:
7266
factory = BoolSignal
7367
expected_type = bool
74-
elif dtype == DType.UnsignedInt:
68+
elif dtype_ == dtype.uint64:
7569
factory = UnsignedIntSignal
7670
expected_type = int
77-
elif dtype == DType.Int:
71+
elif dtype_ == dtype.int64:
7872
factory = IntSignal
7973
expected_type = int
80-
elif dtype == DType.Float:
74+
elif dtype_ == dtype.float64:
8175
factory = FloatSignal
8276
expected_type = float
8377
else:
84-
raise ValueError(f"unsupported dtype {dtype}")
78+
raise ValueError(f"unsupported dtype_ {dtype}")
8579

8680
if data is None:
8781
return factory.from_samples([])
@@ -92,7 +86,7 @@ def signal(
9286

9387

9488
__all__ = [
95-
"DType",
89+
"dtype",
9690
"declare_var",
9791
"literal",
9892
"signal",

pyargus/argus/_argus.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,11 @@ class Until(BoolExpr):
123123
def __init__(self, lhs: BoolExpr, rhs: BoolExpr) -> None: ...
124124

125125
@final
126-
class DType:
127-
Bool: ClassVar[DType] = ...
128-
Float: ClassVar[DType] = ...
129-
Int: ClassVar[DType] = ...
130-
UnsignedInt: ClassVar[DType] = ...
126+
class dtype: # noqa: N801
127+
bool_: ClassVar[dtype] = ...
128+
float64: ClassVar[dtype] = ...
129+
int64: ClassVar[dtype] = ...
130+
uint64: ClassVar[dtype] = ...
131131

132132
@classmethod
133133
def convert(cls, dtype: type[bool | int | float] | Self) -> Self: ... # noqa: Y041
@@ -143,7 +143,7 @@ class Signal(Generic[_SignalKind], Protocol):
143143
@property
144144
def end_time(self) -> float | None: ...
145145
@property
146-
def kind(self) -> type[bool | int | float]: ...
146+
def kind(self) -> dtype: ...
147147

148148
@final
149149
class BoolSignal(Signal[bool]):

pyargus/src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,16 @@ impl From<PyArgusError> for PyErr {
2828
}
2929
}
3030

31-
#[pyclass(module = "argus")]
31+
#[pyclass(module = "argus", name = "dtype")]
3232
#[derive(Copy, Clone, Debug)]
3333
pub enum DType {
34+
#[pyo3(name = "bool_")]
3435
Bool,
36+
#[pyo3(name = "int64")]
3537
Int,
38+
#[pyo3(name = "uint64")]
3639
UnsignedInt,
40+
#[pyo3(name = "float64")]
3741
Float,
3842
}
3943

pyargus/src/signals.rs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use argus_core::signals::interpolation::Linear;
22
use argus_core::signals::Signal;
33
use pyo3::prelude::*;
4-
use pyo3::types::{PyBool, PyFloat, PyInt, PyType};
4+
use pyo3::types::PyType;
55

6-
use crate::PyArgusError;
6+
use crate::{DType, PyArgusError};
77

88
#[pyclass(name = "InterpolationMethod", module = "argus")]
99
#[derive(Debug, Clone, Copy, Default)]
@@ -21,22 +21,30 @@ pub enum SignalKind {
2121
Float(Signal<f64>),
2222
}
2323

24+
impl SignalKind {
25+
/// Get the kind of the signal
26+
pub fn kind(&self) -> DType {
27+
match self {
28+
SignalKind::Bool(_) => DType::Bool,
29+
SignalKind::Int(_) => DType::Int,
30+
SignalKind::UnsignedInt(_) => DType::UnsignedInt,
31+
SignalKind::Float(_) => DType::Float,
32+
}
33+
}
34+
}
35+
2436
#[pyclass(name = "Signal", subclass, module = "argus")]
2537
#[derive(Debug, Clone)]
2638
pub struct PySignal {
27-
pub interpolation: PyInterp,
28-
pub signal: SignalKind,
39+
pub(crate) interpolation: PyInterp,
40+
pub(crate) signal: SignalKind,
2941
}
3042

3143
#[pymethods]
3244
impl PySignal {
3345
#[getter]
34-
fn kind<'py>(&self, py: Python<'py>) -> &'py PyType {
35-
match self.signal {
36-
SignalKind::Bool(_) => PyType::new::<PyBool>(py),
37-
SignalKind::Int(_) | SignalKind::UnsignedInt(_) => PyType::new::<PyInt>(py),
38-
SignalKind::Float(_) => PyType::new::<PyFloat>(py),
39-
}
46+
fn kind(&self) -> DType {
47+
self.signal.kind()
4048
}
4149

4250
fn __repr__(&self) -> String {

pyargus/tests/test_signals.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,39 @@
66
from hypothesis.strategies import SearchStrategy, composite
77

88
import argus
9-
from argus import DType
9+
from argus import AllowedDtype, dtype
1010

11-
AllowedDtype = Union[bool, int, float]
1211

13-
14-
def gen_element_fn(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[AllowedDtype]:
15-
new_dtype = DType.convert(dtype)
16-
if new_dtype == DType.Bool:
12+
def gen_element_fn(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[AllowedDtype]:
13+
new_dtype = dtype.convert(dtype_)
14+
if new_dtype == dtype.bool_:
1715
return st.booleans()
18-
elif new_dtype == DType.Int:
16+
elif new_dtype == dtype.int64:
1917
size = 2**64
2018
return st.integers(min_value=(-size // 2), max_value=((size - 1) // 2))
21-
elif new_dtype == DType.UnsignedInt:
19+
elif new_dtype == dtype.uint64:
2220
size = 2**64
2321
return st.integers(min_value=0, max_value=(size - 1))
24-
elif new_dtype == DType.Float:
22+
elif new_dtype == dtype.float64:
2523
return st.floats(
2624
width=64,
2725
allow_nan=False,
2826
allow_infinity=False,
2927
allow_subnormal=False,
3028
)
3129
else:
32-
raise ValueError(f"invalid dtype {dtype}")
30+
raise ValueError(f"invalid dtype {dtype_}")
3331

3432

3533
@composite
3634
def gen_samples(
37-
draw: st.DrawFn, *, min_size: int, max_size: int, dtype: Union[Type[AllowedDtype], DType]
35+
draw: st.DrawFn, min_size: int, max_size: int, dtype_: Union[Type[AllowedDtype], dtype]
3836
) -> List[Tuple[float, AllowedDtype]]:
3937
"""
4038
Generate arbitrary samples for a signal where the time stamps are strictly
4139
monotonically increasing
4240
"""
43-
elements = gen_element_fn(dtype)
41+
elements = gen_element_fn(dtype_)
4442
values = draw(st.lists(elements, min_size=min_size, max_size=max_size))
4543
xs = draw(
4644
st.lists(
@@ -55,28 +53,28 @@ def gen_samples(
5553
return xs
5654

5755

58-
def empty_signal(*, dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]:
59-
new_dtype: DType = DType.convert(dtype)
56+
def empty_signal(*, dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
57+
new_dtype: dtype = dtype.convert(dtype_)
6058
sig: argus.Signal
61-
if new_dtype == DType.Bool:
59+
if new_dtype == dtype.bool_:
6260
sig = argus.BoolSignal()
63-
assert sig.kind is bool
64-
elif new_dtype == DType.UnsignedInt:
61+
assert sig.kind == dtype.bool_
62+
elif new_dtype == dtype.uint64:
6563
sig = argus.UnsignedIntSignal()
66-
assert sig.kind is int
67-
elif new_dtype == DType.Int:
64+
assert sig.kind == dtype.uint64
65+
elif new_dtype == dtype.int64:
6866
sig = argus.IntSignal()
69-
assert sig.kind is int
70-
elif new_dtype == DType.Float:
67+
assert sig.kind == dtype.int64
68+
elif new_dtype == dtype.float64:
7169
sig = argus.FloatSignal()
72-
assert sig.kind is float
70+
assert sig.kind == dtype.float64
7371
else:
7472
raise ValueError("unknown dtype")
7573
return st.just(sig)
7674

7775

78-
def constant_signal(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]:
79-
return gen_element_fn(dtype).map(lambda val: argus.signal(dtype, data=val))
76+
def constant_signal(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
77+
return gen_element_fn(dtype_).map(lambda val: argus.signal(dtype_, data=val))
8078

8179

8280
@composite
@@ -87,16 +85,16 @@ def draw_index(draw: st.DrawFn, vec: List) -> int:
8785
return draw(st.just(0))
8886

8987

90-
def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], DType]]:
88+
def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], dtype]]:
9189
return st.one_of(
92-
list(map(st.just, [DType.Bool, DType.UnsignedInt, DType.Int, DType.Float, bool, int, float])), # type: ignore[arg-type]
90+
list(map(st.just, [dtype.bool_, dtype.uint64, dtype.int64, dtype.float64, bool, int, float])), # type: ignore[arg-type]
9391
)
9492

9593

9694
@given(st.data())
9795
def test_correct_constant_signals(data: st.DataObject) -> None:
98-
dtype = data.draw(gen_dtype())
99-
signal = data.draw(constant_signal(dtype))
96+
dtype_ = data.draw(gen_dtype())
97+
signal = data.draw(constant_signal(dtype_))
10098

10199
assert not signal.is_empty()
102100
assert signal.start_time is None
@@ -105,11 +103,11 @@ def test_correct_constant_signals(data: st.DataObject) -> None:
105103

106104
@given(st.data())
107105
def test_correctly_create_signals(data: st.DataObject) -> None:
108-
dtype = data.draw(gen_dtype())
109-
xs = data.draw(gen_samples(min_size=0, max_size=100, dtype=dtype))
106+
dtype_ = data.draw(gen_dtype())
107+
xs = data.draw(gen_samples(min_size=0, max_size=100, dtype_=dtype_))
110108

111109
note(f"Samples: {gen_samples}")
112-
signal = argus.signal(dtype, data=xs)
110+
signal = argus.signal(dtype_, data=xs)
113111
if len(xs) > 0:
114112
expected_start_time = xs[0][0]
115113
expected_end_time = xs[-1][0]
@@ -132,7 +130,7 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
132130

133131
# generate one more sample
134132
new_time = actual_end_time + 1
135-
new_value = data.draw(gen_element_fn(dtype))
133+
new_value = data.draw(gen_element_fn(dtype_))
136134
signal.push(new_time, new_value) # type: ignore[arg-type]
137135

138136
get_val = signal.at(new_time)
@@ -148,8 +146,8 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
148146

149147
@given(st.data())
150148
def test_signal_create_should_fail(data: st.DataObject) -> None:
151-
dtype = data.draw(gen_dtype())
152-
xs = data.draw(gen_samples(min_size=10, max_size=100, dtype=dtype))
149+
dtype_ = data.draw(gen_dtype())
150+
xs = data.draw(gen_samples(min_size=10, max_size=100, dtype_=dtype_))
153151
a = data.draw(draw_index(xs))
154152
b = data.draw(draw_index(xs))
155153
assume(a != b)
@@ -161,24 +159,24 @@ def test_signal_create_should_fail(data: st.DataObject) -> None:
161159
xs[b], xs[a] = xs[a], xs[b]
162160

163161
with pytest.raises(RuntimeError, match=r"trying to create a non-monotonically signal.+"):
164-
_ = argus.signal(dtype, data=xs)
162+
_ = argus.signal(dtype_, data=xs)
165163

166164

167165
@given(st.data())
168166
def test_push_to_empty_signal(data: st.DataObject) -> None:
169-
dtype = data.draw(gen_dtype())
170-
sig = data.draw(empty_signal(dtype=dtype))
167+
dtype_ = data.draw(gen_dtype())
168+
sig = data.draw(empty_signal(dtype_=dtype_))
171169
assert sig.is_empty()
172-
element = data.draw(gen_element_fn(dtype))
170+
element = data.draw(gen_element_fn(dtype_))
173171
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
174172
sig.push(0.0, element) # type: ignore[attr-defined]
175173

176174

177175
@given(st.data())
178176
def test_push_to_constant_signal(data: st.DataObject) -> None:
179-
dtype = data.draw(gen_dtype())
180-
sig = data.draw(constant_signal(dtype=dtype))
177+
dtype_ = data.draw(gen_dtype())
178+
sig = data.draw(constant_signal(dtype_=dtype_))
181179
assert not sig.is_empty()
182-
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype=dtype))[0]
180+
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype_=dtype_))[0]
183181
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
184182
sig.push(*sample) # type: ignore[attr-defined]

0 commit comments

Comments
 (0)