Skip to content

Commit 67df337

Browse files
authored
waveform: Implement copy and pickle (#13)
* tests: Add copy/pickle test cases * pyproject.toml: Use dev branch of hightime for copy/pickle fixes * waveform: Fix waveform timing copy/pickle * tests: Add more scale mode copy/pickle tests * waveform: Fix AnalogWaveform copy/pickle * waveform: Use Self instead of duplicating type * waveform: Document that waveform timing objects are immutable and copy the timestamps list by default * tests: Refactor some duplicated params * waveform: Use _typing.Self * waveform: Enable shallow copy of extended properties * Only use dev branch of hightime for testing, not in distribution package * pyproject.toml: Use hightime main branch for testing * Update poetry.lock * tests: Use assert_deep_copy in pickle test
1 parent 375f6da commit 67df337

File tree

13 files changed

+511
-61
lines changed

13 files changed

+511
-61
lines changed

poetry.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ numpy = [
1414
{ version = ">=1.26", python = ">=3.12,<3.13" },
1515
{ version = ">=2.1", python = "^3.13" },
1616
]
17-
# hightime = "^0.2.2"
18-
hightime = { git = "https://github.com/ni/hightime.git" }
17+
hightime = "^0.2.2"
1918

2019
[tool.poetry.group.lint.dependencies]
2120
bandit = { version = ">=1.7", extras = ["toml"] }
@@ -26,6 +25,8 @@ mypy = ">=1.0"
2625
pytest = ">=7.2"
2726
pytest-cov = ">=4.0"
2827
pytest-mock = ">=3.0"
28+
# Use an unreleased version of hightime for testing.
29+
hightime = { git = "https://github.com/ni/hightime.git" }
2930

3031
[tool.poetry.group.docs]
3132
optional = true

src/nitypes/waveform/_analog_waveform.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from nitypes._arguments import arg_to_uint, validate_dtype, validate_unsupported_arg
1414
from nitypes._exceptions import invalid_arg_type, invalid_array_ndim
15-
from nitypes._typing import TypeAlias
15+
from nitypes._typing import Self, TypeAlias
1616
from nitypes.waveform._extended_properties import (
1717
CHANNEL_NAME,
1818
UNIT_DESCRIPTION,
@@ -351,6 +351,7 @@ def __init__(
351351
start_index: SupportsIndex | None = None,
352352
capacity: SupportsIndex | None = None,
353353
extended_properties: Mapping[str, ExtendedPropertyValue] | None = None,
354+
copy_extended_properties: bool = True,
354355
timing: Timing | PrecisionTiming | None = None,
355356
scale_mode: ScaleMode | None = None,
356357
) -> None:
@@ -368,6 +369,8 @@ def __init__(
368369
capacity: The number of samples to allocate. Pre-allocating a larger buffer optimizes
369370
appending samples to the waveform.
370371
extended_properties: The extended properties of the analog waveform.
372+
copy_extended_properties: Specifies whether to copy the extended properties or take
373+
ownership.
371374
timing: The timing information of the analog waveform.
372375
scale_mode: The scale mode of the analog waveform.
373376
@@ -389,7 +392,11 @@ def __init__(
389392
else:
390393
raise invalid_arg_type("raw data", "NumPy ndarray", raw_data)
391394

392-
self._extended_properties = ExtendedPropertyDictionary(extended_properties)
395+
if copy_extended_properties or not isinstance(
396+
extended_properties, ExtendedPropertyDictionary
397+
):
398+
extended_properties = ExtendedPropertyDictionary(extended_properties)
399+
self._extended_properties = extended_properties
393400

394401
if timing is None:
395402
timing = Timing.empty
@@ -864,6 +871,22 @@ def __eq__(self, value: object, /) -> bool:
864871
and self._scale_mode == value._scale_mode
865872
)
866873

874+
def __reduce__(self) -> tuple[Any, ...]:
875+
"""Return object state for pickling."""
876+
ctor_args = (self._sample_count, self.dtype)
877+
ctor_kwargs: dict[str, Any] = {
878+
"raw_data": self.raw_data,
879+
"extended_properties": self._extended_properties,
880+
"copy_extended_properties": False,
881+
"timing": self._timing,
882+
"scale_mode": self._scale_mode,
883+
}
884+
return (self.__class__._unpickle, (ctor_args, ctor_kwargs))
885+
886+
@classmethod
887+
def _unpickle(cls, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Self:
888+
return cls(*args, **kwargs)
889+
867890
def __repr__(self) -> str:
868891
"""Return repr(self)."""
869892
args = [f"{self._sample_count}"]

src/nitypes/waveform/_extended_properties.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
class ExtendedPropertyDictionary(MutableMapping[str, ExtendedPropertyValue]):
2020
"""A dictionary of extended properties."""
2121

22+
__slots__ = ["_properties"]
23+
2224
def __init__(self, properties: Mapping[str, ExtendedPropertyValue] | None = None, /) -> None:
2325
"""Construct an ExtendedPropertyDictionary."""
2426
self._properties: dict[str, ExtendedPropertyValue] = {}

src/nitypes/waveform/_timing/_base.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import operator
55
from abc import ABC, abstractmethod
66
from collections.abc import Iterable, Sequence
7-
from typing import Generic, SupportsIndex, TypeVar
7+
from typing import Any, Generic, SupportsIndex, TypeVar
88

99
from nitypes._exceptions import add_note
1010
from nitypes._typing import Self
@@ -14,13 +14,15 @@
1414
create_sample_interval_strategy,
1515
)
1616

17-
1817
_TDateTime = TypeVar("_TDateTime", bound=dt.datetime)
1918
_TTimeDelta = TypeVar("_TTimeDelta", bound=dt.timedelta)
2019

2120

2221
class BaseTiming(ABC, Generic[_TDateTime, _TTimeDelta]):
23-
"""Base class for waveform timing information."""
22+
"""Base class for waveform timing information.
23+
24+
Waveform timing objects are immutable.
25+
"""
2426

2527
@classmethod
2628
@abstractmethod
@@ -118,8 +120,26 @@ def __init__(
118120
time_offset: _TTimeDelta | None,
119121
sample_interval: _TTimeDelta | None,
120122
timestamps: Sequence[_TDateTime] | None,
123+
*,
124+
copy_timestamps: bool = True,
121125
) -> None:
122-
"""Construct a base waveform timing object."""
126+
"""Construct a waveform timing object.
127+
128+
Args:
129+
sample_interval_mode: The sample interval mode of the waveform timing.
130+
timestamp: The timestamp of the waveform timing. This argument is optional for
131+
SampleIntervalMode.NONE and SampleIntervalMode.REGULAR and unsupported for
132+
SampleIntervalMode.IRREGULAR.
133+
time_offset: The time difference between the timestamp and the first sample. This
134+
argument is optional for SampleIntervalMode.NONE and SampleIntervalMode.REGULAR and
135+
unsupported for SampleIntervalMode.IRREGULAR.
136+
sample_interval: The time interval between samples. This argument is required for
137+
SampleIntervalMode.REGULAR and unsupported otherwise.
138+
timestamps: A sequence containing a timestamp for each sample in the waveform,
139+
specifying the time that the sample was acquired. This argument is required for
140+
SampleIntervalMode.IRREGULAR and unsupported otherwise.
141+
copy_timestamps: Specifies whether to copy the timestamps or take ownership.
142+
"""
123143
sample_interval_strategy = create_sample_interval_strategy(sample_interval_mode)
124144
try:
125145
sample_interval_strategy.validate_init_args(
@@ -129,7 +149,7 @@ def __init__(
129149
add_note(e, f"Sample interval mode: {sample_interval_mode}")
130150
raise
131151

132-
if timestamps is not None and not isinstance(timestamps, list):
152+
if timestamps is not None and (copy_timestamps or not isinstance(timestamps, list)):
133153
timestamps = list(timestamps)
134154

135155
self._sample_interval_strategy = sample_interval_strategy
@@ -212,6 +232,24 @@ def __eq__(self, value: object, /) -> bool:
212232
and self._timestamps == value._timestamps
213233
)
214234

235+
def __reduce__(self) -> tuple[Any, ...]:
236+
"""Return object state for pickling."""
237+
ctor_args = (
238+
self._sample_interval_mode,
239+
self._timestamp,
240+
self._time_offset,
241+
self._sample_interval,
242+
self._timestamps,
243+
)
244+
ctor_kwargs: dict[str, Any] = {}
245+
if self._timestamps is not None:
246+
ctor_kwargs["copy_timestamps"] = False
247+
return (self.__class__._unpickle, (ctor_args, ctor_kwargs))
248+
249+
@classmethod
250+
def _unpickle(cls, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Self:
251+
return cls(*args, **kwargs)
252+
215253
def __repr__(self) -> str:
216254
"""Return repr(self)."""
217255
# For Enum, __str__ is an unqualified ctor expression like E.V and __repr__ is <E.V: 0>.

src/nitypes/waveform/_timing/_precision.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class PrecisionTiming(BaseTiming[ht.datetime, ht.timedelta]):
1414
"""High-precision waveform timing using the hightime package.
1515
1616
The hightime package has up to yoctosecond precision.
17+
18+
Waveform timing objects are immutable.
1719
"""
1820

1921
_DEFAULT_TIME_OFFSET = ht.timedelta()
@@ -66,6 +68,8 @@ def __init__(
6668
time_offset: ht.timedelta | None = None,
6769
sample_interval: ht.timedelta | None = None,
6870
timestamps: Sequence[ht.datetime] | None = None,
71+
*,
72+
copy_timestamps: bool = True,
6973
) -> None:
7074
"""Construct a high-precision waveform timing object.
7175
@@ -74,7 +78,14 @@ def __init__(
7478
- PrecisionTiming.create_with_regular_interval
7579
- PrecisionTiming.create_with_irregular_interval
7680
"""
77-
super().__init__(sample_interval_mode, timestamp, time_offset, sample_interval, timestamps)
81+
super().__init__(
82+
sample_interval_mode,
83+
timestamp,
84+
time_offset,
85+
sample_interval,
86+
timestamps,
87+
copy_timestamps=copy_timestamps,
88+
)
7889

7990

8091
PrecisionTiming.empty = PrecisionTiming.create_with_no_interval()

src/nitypes/waveform/_timing/_standard.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class Timing(BaseTiming[dt.datetime, dt.timedelta]):
1414
1515
The standard datetime module has up to microsecond precision. For higher precision, use
1616
PrecisionTiming.
17+
18+
Waveform timing objects are immutable.
1719
"""
1820

1921
_DEFAULT_TIME_OFFSET = dt.timedelta()
@@ -66,6 +68,8 @@ def __init__(
6668
time_offset: dt.timedelta | None = None,
6769
sample_interval: dt.timedelta | None = None,
6870
timestamps: Sequence[dt.datetime] | None = None,
71+
*,
72+
copy_timestamps: bool = True,
6973
) -> None:
7074
"""Construct a waveform timing object.
7175
@@ -74,7 +78,14 @@ def __init__(
7478
- Timing.create_with_regular_interval
7579
- Timing.create_with_irregular_interval
7680
"""
77-
super().__init__(sample_interval_mode, timestamp, time_offset, sample_interval, timestamps)
81+
super().__init__(
82+
sample_interval_mode,
83+
timestamp,
84+
time_offset,
85+
sample_interval,
86+
timestamps,
87+
copy_timestamps=copy_timestamps,
88+
)
7889

7990

8091
Timing.empty = Timing.create_with_no_interval()

tests/unit/waveform/_scaling/test_linear.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

3+
import copy
4+
import pickle
35
from typing import SupportsFloat
46

57
import numpy as np
68
import numpy.typing as npt
79
import pytest
810

911
from nitypes._typing import assert_type
10-
from nitypes.waveform import LinearScaleMode
12+
from nitypes.waveform import NO_SCALING, LinearScaleMode, ScaleMode
1113

1214

1315
@pytest.mark.parametrize(
@@ -80,7 +82,71 @@ def test___float64_ndarray___transform_data___returns_float64_scaled_data() -> N
8082
assert list(scaled_data) == [4.0, 7.0, 10.0, 13.0]
8183

8284

85+
@pytest.mark.parametrize(
86+
"left, right",
87+
[
88+
(LinearScaleMode(1.0, 0.0), LinearScaleMode(1.0, 0.0)),
89+
(LinearScaleMode(1.2345, 0.006789), LinearScaleMode(1.2345, 0.006789)),
90+
],
91+
)
92+
def test___same_value___equality___equal(left: LinearScaleMode, right: LinearScaleMode) -> None:
93+
assert left == right
94+
assert not (left != right)
95+
96+
97+
@pytest.mark.parametrize(
98+
"left, right",
99+
[
100+
(LinearScaleMode(1.0, 0.0), LinearScaleMode(1.0, 0.1)),
101+
(LinearScaleMode(1.0, 0.0), LinearScaleMode(1.1, 0.0)),
102+
(LinearScaleMode(1.2345, 0.006789), LinearScaleMode(1.23456, 0.006789)),
103+
(LinearScaleMode(1.2345, 0.006789), LinearScaleMode(1.2345, 0.00678)),
104+
(LinearScaleMode(1.0, 0.0), NO_SCALING),
105+
(NO_SCALING, LinearScaleMode(1.0, 0.0)),
106+
],
107+
)
108+
def test___different_value___equality___not_equal(left: ScaleMode, right: ScaleMode) -> None:
109+
assert not (left == right)
110+
assert left != right
111+
112+
83113
def test___scale_mode___repr___looks_ok() -> None:
84114
scale_mode = LinearScaleMode(1.2345, 0.006789)
85115

86116
assert repr(scale_mode) == "nitypes.waveform.LinearScaleMode(1.2345, 0.006789)"
117+
118+
119+
def test___scale_mode___copy___makes_shallow_copy() -> None:
120+
scale_mode = LinearScaleMode(1.2345, 0.006789)
121+
122+
new_scale_mode = copy.copy(scale_mode)
123+
124+
assert new_scale_mode == scale_mode
125+
assert new_scale_mode is not scale_mode
126+
127+
128+
def test___scale_mode___deepcopy___makes_deep_copy() -> None:
129+
scale_mode = LinearScaleMode(1.2345, 0.006789)
130+
131+
new_scale_mode = copy.deepcopy(scale_mode)
132+
133+
assert new_scale_mode == scale_mode
134+
assert new_scale_mode is not scale_mode
135+
136+
137+
def test___scale_mode___pickle_unpickle___makes_deep_copy() -> None:
138+
scale_mode = LinearScaleMode(1.2345, 0.006789)
139+
140+
new_scale_mode = pickle.loads(pickle.dumps(scale_mode))
141+
142+
assert new_scale_mode == scale_mode
143+
assert new_scale_mode is not scale_mode
144+
145+
146+
def test___scale_mode___pickle___references_public_modules() -> None:
147+
scale_mode = LinearScaleMode(1.2345, 0.006789)
148+
149+
scale_mode_bytes = pickle.dumps(scale_mode)
150+
151+
assert b"nitypes.waveform" in scale_mode_bytes
152+
assert b"nitypes.waveform._scaling" not in scale_mode_bytes

tests/unit/waveform/_scaling/test_none.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
import copy
4+
import pickle
5+
36
import numpy as np
47
import numpy.typing as npt
58

@@ -44,3 +47,31 @@ def test___float64_ndarray___transform_data___returns_float64_scaled_data() -> N
4447

4548
def test___scale_mode___repr___looks_ok() -> None:
4649
assert repr(NO_SCALING) == "nitypes.waveform.NoneScaleMode()"
50+
51+
52+
def test___scale_mode___copy___makes_shallow_copy() -> None:
53+
new_scale_mode = copy.copy(NO_SCALING)
54+
55+
assert new_scale_mode == NO_SCALING
56+
assert new_scale_mode is not NO_SCALING
57+
58+
59+
def test___scale_mode___deepcopy___makes_deep_copy() -> None:
60+
new_scale_mode = copy.deepcopy(NO_SCALING)
61+
62+
assert new_scale_mode == NO_SCALING
63+
assert new_scale_mode is not NO_SCALING
64+
65+
66+
def test___scale_mode___pickle_unpickle___makes_deep_copy() -> None:
67+
new_scale_mode = pickle.loads(pickle.dumps(NO_SCALING))
68+
69+
assert new_scale_mode == NO_SCALING
70+
assert new_scale_mode is not NO_SCALING
71+
72+
73+
def test___scale_mode___pickle___references_public_modules() -> None:
74+
scale_mode_bytes = pickle.dumps(NO_SCALING)
75+
76+
assert b"nitypes.waveform" in scale_mode_bytes
77+
assert b"nitypes.waveform._scaling" not in scale_mode_bytes

0 commit comments

Comments
 (0)