Skip to content

Commit f05011c

Browse files
test: Add serde tests for DTypes (#3205)
Co-authored-by: Francesco Bruzzesi <[email protected]>
1 parent 16d6fcb commit f05011c

File tree

4 files changed

+214
-53
lines changed

4 files changed

+214
-53
lines changed

tests/conftest.py

Lines changed: 56 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
import pytest
1111

12+
import narwhals as nw
1213
from narwhals._utils import Implementation, generate_temporary_column_name
1314
from tests.utils import ID_PANDAS_LIKE, PANDAS_VERSION, pyspark_session, sqlframe_session
1415

@@ -26,11 +27,17 @@
2627

2728
from narwhals._spark_like.dataframe import SQLFrameDataFrame
2829
from narwhals._typing import EagerAllowed
29-
from narwhals.typing import NativeDataFrame, NativeLazyFrame
30-
from tests.utils import Constructor, ConstructorEager, ConstructorLazy
30+
from narwhals.typing import NativeDataFrame, NativeLazyFrame, NonNestedDType
31+
from tests.utils import (
32+
Constructor,
33+
ConstructorEager,
34+
ConstructorLazy,
35+
NestedOrEnumDType,
36+
)
3137

3238
Data: TypeAlias = "dict[str, list[Any]]"
3339

40+
3441
MIN_PANDAS_NULLABLE_VERSION = (2,)
3542

3643
# When testing cudf.pandas in Kaggle, we get an error if we try to run
@@ -321,3 +328,50 @@ def eager_backend(request: pytest.FixtureRequest) -> EagerAllowed:
321328
def eager_implementation(request: pytest.FixtureRequest) -> EagerAllowed:
322329
"""Use if a test is heavily parametric, skips `str` backend."""
323330
return request.param # type: ignore[no-any-return]
331+
332+
333+
@pytest.fixture(
    params=(
        nw.Boolean,
        nw.Categorical,
        nw.Date,
        nw.Datetime,
        nw.Decimal,
        nw.Duration,
        nw.Float32,
        nw.Float64,
        nw.Int8,
        nw.Int16,
        nw.Int32,
        nw.Int64,
        nw.Int128,
        nw.Object,
        nw.String,
        nw.Time,
        nw.UInt8,
        nw.UInt16,
        nw.UInt32,
        nw.UInt64,
        nw.UInt128,
        nw.Unknown,
        nw.Binary,
    ),
    # Each param is a class, so its own name is used as the test id.
    ids=lambda dtype_cls: dtype_cls.__name__,
)
def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]:
    """Yield each non-nested dtype *class* (not an instance), one per test run."""
    # The annotated local narrows `request.param` (typed `Any`) for the type checker.
    dtype_cls: type[NonNestedDType] = request.param
    return dtype_cls
364+
365+
366+
@pytest.fixture(
    params=(
        nw.List(nw.Float32),
        nw.Array(nw.String, 2),
        nw.Struct({"a": nw.Boolean}),
        nw.Enum(["beluga", "narwhal"]),
    ),
    # Params are instances here, so the id comes from the instance's class name.
    ids=lambda instance: type(instance).__name__,
)
def nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType:
    """Yield one pre-built `List`, `Array`, `Struct` or `Enum` dtype instance per test run."""
    # The annotated local narrows `request.param` (typed `Any`) for the type checker.
    instance: NestedOrEnumDType = request.param
    return instance

tests/dtypes_test.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,58 +17,8 @@
1717
if TYPE_CHECKING:
1818
from collections.abc import Iterable
1919

20-
from typing_extensions import TypeAlias
21-
2220
from narwhals.typing import IntoFrame, IntoSeries, NonNestedDType
23-
from tests.utils import Constructor, ConstructorPandasLike
24-
25-
NestedOrEnumDType: TypeAlias = "nw.List | nw.Array | nw.Struct | nw.Enum"
26-
"""`DType`s which **cannot** be used as bare types."""
27-
28-
29-
@pytest.fixture(
30-
params=[
31-
nw.Boolean,
32-
nw.Categorical,
33-
nw.Date,
34-
nw.Datetime,
35-
nw.Decimal,
36-
nw.Duration,
37-
nw.Float32,
38-
nw.Float64,
39-
nw.Int8,
40-
nw.Int16,
41-
nw.Int32,
42-
nw.Int64,
43-
nw.Int128,
44-
nw.Object,
45-
nw.String,
46-
nw.Time,
47-
nw.UInt8,
48-
nw.UInt16,
49-
nw.UInt32,
50-
nw.UInt64,
51-
nw.UInt128,
52-
nw.Unknown,
53-
nw.Binary,
54-
],
55-
ids=lambda tp: tp.__name__,
56-
)
57-
def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]:
58-
return request.param # type: ignore[no-any-return]
59-
60-
61-
@pytest.fixture(
62-
params=[
63-
nw.List(nw.Float32),
64-
nw.Array(nw.String, 2),
65-
nw.Struct({"a": nw.Boolean}),
66-
nw.Enum(["beluga", "narwhal"]),
67-
],
68-
ids=lambda obj: type(obj).__name__,
69-
)
70-
def nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType:
71-
return request.param # type: ignore[no-any-return]
21+
from tests.utils import Constructor, ConstructorPandasLike, NestedOrEnumDType
7222

7323

7424
@pytest.mark.parametrize("time_unit", ["us", "ns", "ms"])

tests/serde_test.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""Serialization tests, based on [py-polars/tests/unit/test_serde.py].
2+
3+
See also [Pickling Class Instances](https://docs.python.org/3/library/pickle.html#pickling-class-instances).
4+
5+
[py-polars/tests/unit/test_serde.py]: https://github.com/pola-rs/polars/blob/a143eb0d7077ee9da2ce209a19c21d7f82228081/py-polars/tests/unit/test_serde.py
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import pickle
11+
import string
12+
13+
# ruff: noqa: S301
14+
from typing import TYPE_CHECKING, Protocol, TypeVar
15+
16+
import pytest
17+
18+
import narwhals as nw
19+
import narwhals.stable.v1 as nw_v1
20+
from narwhals.dtypes import DType
21+
from narwhals.typing import IntoDType, NonNestedDType, TimeUnit
22+
23+
if TYPE_CHECKING:
24+
from narwhals.typing import DTypes
25+
from tests.utils import NestedOrEnumDType
26+
27+
28+
# Bound to `IntoDType` so a round-trip callable can be typed as "same type out as in".
IntoDTypeT = TypeVar("IntoDTypeT", bound=IntoDType)

# Reusable parametrizations, applied as decorators on the tests below.
namespaces = pytest.mark.parametrize("namespace", (nw, nw_v1))
time_units = pytest.mark.parametrize("time_unit", ("ns", "us", "ms", "s"))
33+
34+
35+
class Identity(Protocol):
    """Structural type for a callable that hands back a value of the type it received."""

    def __call__(self, obj: IntoDTypeT, /) -> IntoDTypeT:
        """Accept `obj` positionally and return a value of the same type."""
        ...
37+
38+
39+
def _roundtrip_pickle(protocol: int | None = None) -> Identity:
    """Build a callable that pickles its argument and immediately unpickles it.

    Arguments:
        protocol: Pickle protocol forwarded to `pickle.dumps`. `None` leaves
            the choice to `pickle`.

    Returns:
        A function returning the unpickled copy of whatever object it is given.
    """

    def _dump_then_load(obj: IntoDTypeT, /) -> IntoDTypeT:
        payload = pickle.dumps(obj, protocol)
        # Only data serialized on the line above is ever loaded here.
        restored: IntoDTypeT = pickle.loads(payload)
        return restored

    return _dump_then_load
45+
46+
47+
@pytest.fixture(
    params=[
        # Pairing each callable with its id keeps the two from drifting apart.
        pytest.param(_roundtrip_pickle(), id="pickle-None"),
        pytest.param(_roundtrip_pickle(4), id="pickle-4"),
        pytest.param(_roundtrip_pickle(5), id="pickle-5"),
    ]
)
def roundtrip(request: pytest.FixtureRequest) -> Identity:
    """Yield a pickle round-trip callable for the default protocol, then protocols 4 and 5."""
    # The annotated local narrows `request.param` (typed `Any`) for the type checker.
    roundtrip_fn: Identity = request.param
    return roundtrip_fn
54+
55+
56+
@namespaces
@time_units
def test_serde_datetime_dtype(
    namespace: DTypes, time_unit: TimeUnit, roundtrip: Identity
) -> None:
    """A `Datetime` built with each time unit compares equal after a round trip."""
    original = namespace.Datetime(time_unit)
    assert roundtrip(original) == original
64+
65+
66+
@namespaces
@time_units
def test_serde_duration_dtype(
    namespace: DTypes, time_unit: TimeUnit, roundtrip: Identity
) -> None:
    """A `Duration` built with each time unit compares equal after a round trip."""
    original = namespace.Duration(time_unit)
    assert roundtrip(original) == original
74+
75+
76+
def test_serde_doubly_nested_struct_dtype(roundtrip: Identity) -> None:
    """A `Struct` whose single field holds a `List` compares equal after a round trip."""
    list_field = nw.Field("a", nw.List(nw.String))
    original = nw.Struct([list_field])
    assert roundtrip(original) == original
80+
81+
82+
def test_serde_doubly_nested_array_dtype(roundtrip: Identity) -> None:
    """An `Array` of `Array`s compares equal after a round trip."""
    inner = nw.Array(nw.Int32(), 2)
    original = nw.Array(inner, 3)
    assert roundtrip(original) == original
86+
87+
88+
def test_serde_dtype_class(roundtrip: Identity) -> None:
    """Round-tripping the `Datetime` class itself yields an equal object that is still a class."""
    restored = roundtrip(nw.Datetime)
    assert restored == nw.Datetime
    assert isinstance(restored, type)
93+
94+
95+
def test_serde_enum_dtype(roundtrip: Identity) -> None:
    """An `Enum` with explicit categories round-trips to an equal `DType` instance."""
    original = nw.Enum(["a", "b"])
    restored = roundtrip(original)
    assert restored == original
    assert isinstance(restored, DType)
100+
101+
102+
def test_serde_enum_v1_dtype(roundtrip: Identity) -> None:
    """The stable-v1 `Enum` round-trips as a v1 `Enum` whose class still rejects categories."""
    original = nw_v1.Enum()
    restored = roundtrip(original)
    assert restored == original
    assert isinstance(restored, nw_v1.Enum)
    # Calling the restored object's class with categories is expected to raise.
    restored_cls = type(restored)
    with pytest.raises(TypeError):
        restored_cls(["a", "b"])  # type: ignore[call-arg]
110+
111+
112+
def test_serde_enum_deferred(roundtrip: Identity) -> None:
    """An `Enum` dtype taken from a polars-backed series keeps its categories across round trips."""
    pytest.importorskip("polars")
    import polars as pl

    categories = tuple(string.printable)
    dtype_pl = pl.Enum(categories)
    casted_pl = pl.Series(categories).cast(dtype_pl)
    series_pl = casted_pl.extend_constant(categories[5], 1000)
    series_nw = nw.from_native(series_pl, series_only=True)

    dtype_nw = series_nw.dtype
    assert isinstance(dtype_nw, nw.Enum)

    restored = roundtrip(dtype_nw)
    assert isinstance(restored, nw.Enum)
    assert restored == dtype_nw
    assert restored == series_nw.dtype
    # A second round trip must still agree with the dtype read from the series.
    assert dtype_nw == roundtrip(restored)

    rebuilt_categories = type(restored)(dtype_pl.categories).categories
    assert (
        rebuilt_categories
        == roundtrip(restored).categories
        == categories
        == restored.categories
        == roundtrip(dtype_nw).categories
    )
134+
135+
136+
def test_serde_non_nested_dtypes(
    non_nested_type: type[NonNestedDType], roundtrip: Identity
) -> None:
    """An instance of each non-nested dtype round-trips to an equal instance of the same class."""
    restored = roundtrip(non_nested_type())
    assert isinstance(restored, DType)
    assert isinstance(restored, non_nested_type)
    # Compared against both a fresh instance and the bare class.
    assert restored == non_nested_type()
    assert restored == non_nested_type
145+
146+
147+
def test_serde_nested_dtypes(
    nested_dtype: NestedOrEnumDType, roundtrip: Identity
) -> None:
    """Each nested or enum dtype instance round-trips to an equal instance of the same class."""
    restored = roundtrip(nested_dtype)
    assert isinstance(restored, DType)
    assert isinstance(restored, type(nested_dtype))
    # Compared against both the original instance and its `base_type()`.
    assert restored == nested_dtype
    assert restored == nested_dtype.base_type()

tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]:
4848
ConstructorLazy: TypeAlias = Callable[[Any], "NativeLazyFrame"]
4949
ConstructorPandasLike: TypeAlias = Callable[[Any], "pd.DataFrame"]
5050

51+
NestedOrEnumDType: TypeAlias = "nw.List | nw.Array | nw.Struct | nw.Enum"
52+
"""`DType`s which **cannot** be used as bare types."""
53+
5154
ID_PANDAS_LIKE = frozenset(
5255
("pandas", "pandas[nullable]", "pandas[pyarrow]", "modin", "modin[pyarrow]", "cudf")
5356
)

0 commit comments

Comments
 (0)