Skip to content

Commit f0c5d13

Browse files
committed
waveform: Limit _TScaled_co to supported types and add _TOtherScaled
1 parent 241678d commit f0c5d13

File tree

4 files changed

+64
-24
lines changed

4 files changed

+64
-24
lines changed

src/nitypes/waveform/_analog_waveform.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,32 @@
3636
if sys.version_info < (3, 10):
3737
import array as std_array
3838

39-
39+
# _TRaw and _TRaw_co specify the type of the raw_data array. They are not limited to supported
40+
# types. Requesting an unsupported type raises TypeError at run time.
4041
_TRaw = TypeVar("_TRaw", bound=np.generic)
4142
_TRaw_co = TypeVar("_TRaw_co", bound=np.generic, covariant=True)
4243

43-
# default requires Python 3.13+ or typing_extensions.
44-
_TScaled = TypeVar("_TScaled", bound=np.generic, default=np.float64)
45-
_TScaled_co = TypeVar("_TScaled_co", bound=np.generic, default=np.float64, covariant=True)
44+
# _TScaled_co specifies the type of the scaled_data property, which is a double precision
45+
# floating-point number or complex number. This type variable has a default so that clients can
46+
# omit it and write type hints like AnalogWaveform[np.int32]. Default requires Python 3.13+ or
47+
# typing_extensions.
48+
_TScaled_co = TypeVar(
49+
"_TScaled_co", bound=Union[np.float64, np.complex128], default=np.float64, covariant=True
50+
)
4651

47-
# Note: ComplexInt32Base is currently an alias for np.void, so this currently matches any NumPy
48-
# structured data type based on np.void. However, constructing with a different structured data
49-
# type will fail at run time.
50-
_TComplexRaw = TypeVar("_TComplexRaw", bound=Union[np.complexfloating, ComplexInt32Base])
51-
_TComplexRaw_co = TypeVar(
52-
"_TComplexRaw_co", bound=Union[np.complexfloating, ComplexInt32Base], covariant=True
52+
# _TOtherScaled is for the get_scaled_data() method, which supports both single and
53+
# double precision.
54+
_TOtherScaled = TypeVar(
55+
"_TOtherScaled", bound=Union[np.float32, np.float64, np.complex64, np.complex128]
5356
)
5457

58+
# _TComplexRaw is for constructor overloads that only match complex numbers. These overloads enable
59+
# the type checker to infer that _TScaled_co should also be a complex type.
60+
#
61+
# Note: ComplexInt32Base is currently an alias for np.void, so this type variables matches any NumPy
62+
# structured data type that is based on np.void. Mismatches raise TypeError at run time.
63+
_TComplexRaw = TypeVar("_TComplexRaw", bound=Union[np.complex64, np.complex128, ComplexInt32Base])
64+
5565
_AnyTiming: TypeAlias = Union[BaseTiming[Any, Any], Timing, PrecisionTiming]
5666
_TTiming = TypeVar("_TTiming", bound=BaseTiming[Any, Any])
5767

@@ -402,9 +412,9 @@ def __init__( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa
402412

403413
@overload
404414
def __init__( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa)
405-
self: AnalogWaveform[_TComplexRaw_co, np.complex128],
415+
self: AnalogWaveform[_TComplexRaw, np.complex128],
406416
sample_count: SupportsIndex | None = ...,
407-
dtype: type[_TComplexRaw_co] | np.dtype[_TComplexRaw_co] = ...,
417+
dtype: type[_TComplexRaw] | np.dtype[_TComplexRaw] = ...,
408418
*,
409419
raw_data: None = ...,
410420
start_index: SupportsIndex | None = ...,
@@ -416,11 +426,11 @@ def __init__( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa
416426

417427
@overload
418428
def __init__( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa)
419-
self: AnalogWaveform[_TComplexRaw_co, np.complex128],
429+
self: AnalogWaveform[_TRaw, np.float64],
420430
sample_count: SupportsIndex | None = ...,
421-
dtype: None = ...,
431+
dtype: type[_TRaw] | np.dtype[_TRaw] = ...,
422432
*,
423-
raw_data: npt.NDArray[_TComplexRaw_co] = ...,
433+
raw_data: None = ...,
424434
start_index: SupportsIndex | None = ...,
425435
capacity: SupportsIndex | None = ...,
426436
extended_properties: Mapping[str, ExtendedPropertyValue] | None = ...,
@@ -430,11 +440,11 @@ def __init__( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa
430440

431441
@overload
432442
def __init__( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa)
433-
self: AnalogWaveform[_TRaw_co, np.float64],
443+
self: AnalogWaveform[_TComplexRaw, np.complex128],
434444
sample_count: SupportsIndex | None = ...,
435-
dtype: type[_TRaw_co] | np.dtype[_TRaw_co] = ...,
445+
dtype: None = ...,
436446
*,
437-
raw_data: None = ...,
447+
raw_data: npt.NDArray[_TComplexRaw] = ...,
438448
start_index: SupportsIndex | None = ...,
439449
capacity: SupportsIndex | None = ...,
440450
extended_properties: Mapping[str, ExtendedPropertyValue] | None = ...,
@@ -444,11 +454,11 @@ def __init__( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa
444454

445455
@overload
446456
def __init__( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa)
447-
self: AnalogWaveform[_TRaw_co, np.float64],
457+
self: AnalogWaveform[_TRaw, np.float64],
448458
sample_count: SupportsIndex | None = ...,
449459
dtype: None = ...,
450460
*,
451-
raw_data: npt.NDArray[_TRaw_co] = ...,
461+
raw_data: npt.NDArray[_TRaw] = ...,
452462
start_index: SupportsIndex | None = ...,
453463
capacity: SupportsIndex | None = ...,
454464
extended_properties: Mapping[str, ExtendedPropertyValue] | None = ...,
@@ -680,11 +690,11 @@ def get_scaled_data( # noqa: D107 - Missing docstring in __init__ (auto-generat
680690
@overload
681691
def get_scaled_data( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa)
682692
self,
683-
dtype: type[_TScaled] | np.dtype[_TScaled] = ...,
693+
dtype: type[_TOtherScaled] | np.dtype[_TOtherScaled] = ...,
684694
*,
685695
start_index: SupportsIndex | None = ...,
686696
sample_count: SupportsIndex | None = ...,
687-
) -> npt.NDArray[_TScaled]: ...
697+
) -> npt.NDArray[_TOtherScaled]: ...
688698

689699
@overload
690700
def get_scaled_data( # noqa: D107 - Missing docstring in __init__ (auto-generated noqa)

src/nitypes/waveform/_complex_waveform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import numpy as np
44
from typing_extensions import TypeAlias
55

6-
from nitypes.waveform._analog_waveform import AnalogWaveform, _TComplexRaw_co
6+
from nitypes.waveform._analog_waveform import AnalogWaveform, _TComplexRaw
77

8-
ComplexWaveform: TypeAlias = AnalogWaveform[_TComplexRaw_co, np.complex128]
8+
ComplexWaveform: TypeAlias = AnalogWaveform[_TComplexRaw, np.complex128]
99
"""An analog waveform containing complex-number data.
1010
1111
.. note::

tests/unit/waveform/test_analog_waveform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,17 @@ def test___sample_count_and_unsupported_dtype___create___raises_type_error() ->
9595
assert exc.value.args[0].startswith("The requested data type is not supported.")
9696

9797

98+
def test___dtype_str_with_traw_hint___create___narrows_traw_and_tscaled() -> None:
99+
waveform: AnalogWaveform[np.int32] = AnalogWaveform(dtype="int32")
100+
101+
assert_type(waveform, AnalogWaveform[np.int32, np.float64])
102+
103+
104+
def test___dtype_str_with_unsupported_tscaled_hint___create___mypy_returns_error() -> None:
105+
waveform: AnalogWaveform[np.int32, np.float32] = AnalogWaveform(dtype="int32") # type: ignore[type-var]
106+
_ = waveform
107+
108+
98109
###############################################################################
99110
# from_array_1d
100111
###############################################################################

tests/unit/waveform/test_complex_waveform.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,25 @@ def test___sample_count_and_unknown_structured_dtype___create___raises_type_erro
5858
assert "Data type: [('a', '<i2'), ('b', '<i4')]" in exc.value.args[0]
5959

6060

61+
def test___sample_count_and_structured_dtype_str___create___raises_type_error() -> None:
62+
with pytest.raises(TypeError) as exc:
63+
_ = AnalogWaveform(10, "i2, i2")
64+
65+
assert exc.value.args[0].startswith("The requested data type is not supported.")
66+
assert "Data type: [('f0', '<i2'), ('f1', '<i2')]" in exc.value.args[0]
67+
68+
69+
def test___dtype_str_with_traw_hint___create___narrows_traw_and_tscaled() -> None:
70+
waveform: ComplexWaveform[np.complex64] = AnalogWaveform(dtype="complex64")
71+
72+
assert_type(waveform, AnalogWaveform[np.complex64, np.complex128])
73+
74+
75+
def test___dtype_str_with_unsupported_tscaled_hint___create___mypy_returns_error() -> None:
76+
waveform: AnalogWaveform[np.complex64, np.complex64] = AnalogWaveform(dtype="complex64") # type: ignore[type-var]
77+
_ = waveform
78+
79+
6180
###############################################################################
6281
# from_array_1d
6382
###############################################################################

0 commit comments

Comments
 (0)