Skip to content

Commit 219fbe4

Browse files
committed
1 parent 849326f commit 219fbe4

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

tests/serde_test.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""Serialization tests, based on [py-polars/tests/unit/test_serde.py].
2+
3+
See also [Pickling Class Instances](https://docs.python.org/3/library/pickle.html#pickling-class-instances).
4+
5+
[py-polars/tests/unit/test_serde.py]: https://github.com/pola-rs/polars/blob/a143eb0d7077ee9da2ce209a19c21d7f82228081/py-polars/tests/unit/test_serde.py
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import pickle
11+
12+
# ruff: noqa: S301
13+
from typing import TYPE_CHECKING, Protocol, TypeVar
14+
15+
import pytest
16+
17+
import narwhals as nw
18+
import narwhals.stable.v1 as nw_v1
19+
from narwhals.dtypes import DType
20+
from narwhals.typing import IntoDType, TimeUnit
21+
22+
if TYPE_CHECKING:
23+
from narwhals.typing import DTypes
24+
25+
26+
IntoDTypeT = TypeVar("IntoDTypeT", bound=IntoDType)
27+
28+
29+
namespaces = pytest.mark.parametrize("namespace", [nw, nw_v1])
30+
time_units = pytest.mark.parametrize("time_unit", ["ns", "us", "ms", "s"])
31+
32+
33+
class Identity(Protocol):
34+
def __call__(self, obj: IntoDTypeT, /) -> IntoDTypeT: ...
35+
36+
37+
def _roundtrip_pickle(protocol: int | None = None) -> Identity:
38+
def fn(obj: IntoDTypeT, /) -> IntoDTypeT:
39+
result: IntoDTypeT = pickle.loads(pickle.dumps(obj, protocol))
40+
return result
41+
42+
return fn
43+
44+
45+
@pytest.fixture(
46+
params=[_roundtrip_pickle(), _roundtrip_pickle(4), _roundtrip_pickle(5)],
47+
ids=["pickle-None", "pickle-4", "pickle-5"],
48+
)
49+
def roundtrip(request: pytest.FixtureRequest) -> Identity:
50+
fn: Identity = request.param
51+
return fn
52+
53+
54+
@namespaces
55+
@time_units
56+
def test_serde_datetime_dtype(
57+
namespace: DTypes, time_unit: TimeUnit, roundtrip: Identity
58+
) -> None:
59+
dtype = namespace.Datetime(time_unit)
60+
result = roundtrip(dtype)
61+
assert result == namespace.Datetime(time_unit)
62+
63+
64+
@namespaces
65+
@time_units
66+
def test_serde_duration_dtype(
67+
namespace: DTypes, time_unit: TimeUnit, roundtrip: Identity
68+
) -> None:
69+
dtype = namespace.Duration(time_unit)
70+
result = roundtrip(dtype)
71+
assert result == namespace.Duration(time_unit)
72+
73+
74+
def test_serde_categorical_dtype(roundtrip: Identity) -> None:
75+
dtype = nw.Categorical()
76+
result = roundtrip(dtype)
77+
assert result == nw.Categorical
78+
79+
80+
def test_serde_doubly_nested_dtype(roundtrip: Identity) -> None:
81+
dtype = nw.Struct([nw.Field("a", nw.List(nw.String))])
82+
result = roundtrip(dtype)
83+
assert result == nw.Struct([nw.Field("a", nw.List(nw.String))])
84+
85+
86+
def test_serde_array_dtype(roundtrip: Identity) -> None:
87+
dtype = nw.Array(nw.Int32(), 3)
88+
result = roundtrip(dtype)
89+
assert result == nw.Array(nw.Int32(), 3)
90+
91+
92+
def test_serde_dtype_class(roundtrip: Identity) -> None:
93+
dtype_class = nw.Datetime
94+
result = roundtrip(dtype_class)
95+
assert result == dtype_class
96+
assert isinstance(result, type)
97+
98+
99+
def test_serde_instantiated_dtype(roundtrip: Identity) -> None:
100+
dtype = nw.Int8()
101+
result = roundtrip(dtype)
102+
assert result == dtype
103+
assert isinstance(result, DType)
104+
105+
106+
def test_serde_enum_dtype(roundtrip: Identity) -> None:
107+
dtype = nw.Enum(["a", "b"])
108+
result = roundtrip(dtype)
109+
assert result == dtype
110+
assert isinstance(result, DType)
111+
112+
113+
def test_serde_enum_v1_dtype(roundtrip: Identity) -> None:
114+
dtype = nw_v1.Enum()
115+
result = roundtrip(dtype)
116+
assert result == dtype
117+
assert isinstance(result, nw_v1.Enum)
118+
tp = type(result)
119+
with pytest.raises(TypeError):
120+
tp(["a", "b"]) # type: ignore[call-arg]

0 commit comments

Comments
 (0)