Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pytest

import narwhals as nw
from narwhals._utils import Implementation, generate_temporary_column_name
from tests.utils import ID_PANDAS_LIKE, PANDAS_VERSION, pyspark_session, sqlframe_session

Expand All @@ -26,11 +27,17 @@

from narwhals._spark_like.dataframe import SQLFrameDataFrame
from narwhals._typing import EagerAllowed
from narwhals.typing import NativeDataFrame, NativeLazyFrame
from tests.utils import Constructor, ConstructorEager, ConstructorLazy
from narwhals.typing import NativeDataFrame, NativeLazyFrame, NonNestedDType
from tests.utils import (
Constructor,
ConstructorEager,
ConstructorLazy,
NestedOrEnumDType,
)

Data: TypeAlias = "dict[str, list[Any]]"


MIN_PANDAS_NULLABLE_VERSION = (2,)

# When testing cudf.pandas in Kaggle, we get an error if we try to run
Expand Down Expand Up @@ -321,3 +328,50 @@ def eager_backend(request: pytest.FixtureRequest) -> EagerAllowed:
def eager_implementation(request: pytest.FixtureRequest) -> EagerAllowed:
"""Use if a test is heavily parametric, skips `str` backend."""
return request.param # type: ignore[no-any-return]


@pytest.fixture(
params=[
nw.Boolean,
nw.Categorical,
nw.Date,
nw.Datetime,
nw.Decimal,
nw.Duration,
nw.Float32,
nw.Float64,
nw.Int8,
nw.Int16,
nw.Int32,
nw.Int64,
nw.Int128,
nw.Object,
nw.String,
nw.Time,
nw.UInt8,
nw.UInt16,
nw.UInt32,
nw.UInt64,
nw.UInt128,
nw.Unknown,
nw.Binary,
],
ids=lambda tp: tp.__name__,
)
def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]:
tp_dtype: type[NonNestedDType] = request.param
return tp_dtype


@pytest.fixture(
params=[
nw.List(nw.Float32),
nw.Array(nw.String, 2),
nw.Struct({"a": nw.Boolean}),
nw.Enum(["beluga", "narwhal"]),
],
ids=lambda obj: type(obj).__name__,
)
def nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType:
dtype: NestedOrEnumDType = request.param
return dtype
52 changes: 1 addition & 51 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,58 +17,8 @@
if TYPE_CHECKING:
from collections.abc import Iterable

from typing_extensions import TypeAlias

from narwhals.typing import IntoFrame, IntoSeries, NonNestedDType
from tests.utils import Constructor, ConstructorPandasLike

NestedOrEnumDType: TypeAlias = "nw.List | nw.Array | nw.Struct | nw.Enum"
"""`DType`s which **cannot** be used as bare types."""


@pytest.fixture(
params=[
nw.Boolean,
nw.Categorical,
nw.Date,
nw.Datetime,
nw.Decimal,
nw.Duration,
nw.Float32,
nw.Float64,
nw.Int8,
nw.Int16,
nw.Int32,
nw.Int64,
nw.Int128,
nw.Object,
nw.String,
nw.Time,
nw.UInt8,
nw.UInt16,
nw.UInt32,
nw.UInt64,
nw.UInt128,
nw.Unknown,
nw.Binary,
],
ids=lambda tp: tp.__name__,
)
def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]:
return request.param # type: ignore[no-any-return]


@pytest.fixture(
params=[
nw.List(nw.Float32),
nw.Array(nw.String, 2),
nw.Struct({"a": nw.Boolean}),
nw.Enum(["beluga", "narwhal"]),
],
ids=lambda obj: type(obj).__name__,
)
def nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType:
return request.param # type: ignore[no-any-return]
from tests.utils import Constructor, ConstructorPandasLike, NestedOrEnumDType


@pytest.mark.parametrize("time_unit", ["us", "ns", "ms"])
Expand Down
154 changes: 154 additions & 0 deletions tests/serde_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Serialization tests, based on [py-polars/tests/unit/test_serde.py].

See also [Pickling Class Instances](https://docs.python.org/3/library/pickle.html#pickling-class-instances).

[py-polars/tests/unit/test_serde.py]: https://github.com/pola-rs/polars/blob/a143eb0d7077ee9da2ce209a19c21d7f82228081/py-polars/tests/unit/test_serde.py
"""

from __future__ import annotations

import pickle
import string

# ruff: noqa: S301
from typing import TYPE_CHECKING, Protocol, TypeVar

import pytest

import narwhals as nw
import narwhals.stable.v1 as nw_v1
from narwhals.dtypes import DType
from narwhals.typing import IntoDType, NonNestedDType, TimeUnit

if TYPE_CHECKING:
from narwhals.typing import DTypes
from tests.utils import NestedOrEnumDType


IntoDTypeT = TypeVar("IntoDTypeT", bound=IntoDType)


namespaces = pytest.mark.parametrize("namespace", [nw, nw_v1])
time_units = pytest.mark.parametrize("time_unit", ["ns", "us", "ms", "s"])


class Identity(Protocol):
def __call__(self, obj: IntoDTypeT, /) -> IntoDTypeT: ...


def _roundtrip_pickle(protocol: int | None = None) -> Identity:
def fn(obj: IntoDTypeT, /) -> IntoDTypeT:
result: IntoDTypeT = pickle.loads(pickle.dumps(obj, protocol))
return result

return fn


@pytest.fixture(
params=[_roundtrip_pickle(), _roundtrip_pickle(4), _roundtrip_pickle(5)],
ids=["pickle-None", "pickle-4", "pickle-5"],
)
def roundtrip(request: pytest.FixtureRequest) -> Identity:
fn: Identity = request.param
return fn
Comment on lines +35 to +53
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea I had here is we can pretty easily extend this by adding more functions that can roundtrip.

pickle is all that I've used, but the protocols which are being called can be used elsewhere

import narwhals as nw

dtype = nw.Struct({"a": nw.List(nw.Array(nw.String, 5))})

>>> dtype.__reduce_ex__(5)
(<function copyreg.__newobj__(cls, *args)>,
 (narwhals.dtypes.Struct,),
 (None,
  {'fields': [Field('a', List(Array(<class 'narwhals.dtypes.String'>, shape=(5,))))]}),
 None,
 None)


>>> dtype.__getstate__()
(None,
 {'fields': [Field('a', List(Array(<class 'narwhals.dtypes.String'>, shape=(5,))))]})

Copy link
Member Author

@dangotbanned dangotbanned Oct 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Barely on topic, but noticable from the example I used

@FBruzzesi nested DTypes are where the metaclass repr from polars would be nice

import narwhals as nw
>>> nw.Struct({"a": nw.List(nw.Array(nw.String, 5))})
Struct({'a': List(Array(<class 'narwhals.dtypes.String'>, shape=(5,)))})
import polars as pl
>>> pl.Struct({"a": pl.List(pl.Array(pl.String, 5))})
Struct({'a': List(Array(String, shape=(5,)))})



@namespaces
@time_units
def test_serde_datetime_dtype(
namespace: DTypes, time_unit: TimeUnit, roundtrip: Identity
) -> None:
dtype = namespace.Datetime(time_unit)
result = roundtrip(dtype)
assert result == namespace.Datetime(time_unit)


@namespaces
@time_units
def test_serde_duration_dtype(
namespace: DTypes, time_unit: TimeUnit, roundtrip: Identity
) -> None:
dtype = namespace.Duration(time_unit)
result = roundtrip(dtype)
assert result == namespace.Duration(time_unit)


def test_serde_doubly_nested_struct_dtype(roundtrip: Identity) -> None:
dtype = nw.Struct([nw.Field("a", nw.List(nw.String))])
result = roundtrip(dtype)
assert result == nw.Struct([nw.Field("a", nw.List(nw.String))])


def test_serde_doubly_nested_array_dtype(roundtrip: Identity) -> None:
dtype = nw.Array(nw.Array(nw.Int32(), 2), 3)
result = roundtrip(dtype)
assert result == nw.Array(nw.Array(nw.Int32(), 2), 3)


def test_serde_dtype_class(roundtrip: Identity) -> None:
dtype_class = nw.Datetime
result = roundtrip(dtype_class)
assert result == dtype_class
assert isinstance(result, type)


def test_serde_enum_dtype(roundtrip: Identity) -> None:
dtype = nw.Enum(["a", "b"])
result = roundtrip(dtype)
assert result == dtype
assert isinstance(result, DType)


def test_serde_enum_v1_dtype(roundtrip: Identity) -> None:
dtype = nw_v1.Enum()
result = roundtrip(dtype)
assert result == dtype
assert isinstance(result, nw_v1.Enum)
tp = type(result)
with pytest.raises(TypeError):
tp(["a", "b"]) # type: ignore[call-arg]


def test_serde_enum_deferred(roundtrip: Identity) -> None:
pytest.importorskip("polars")
import polars as pl

categories = tuple(string.printable)
dtype_pl = pl.Enum(categories)
series_pl = pl.Series(categories).cast(dtype_pl).extend_constant(categories[5], 1000)
series_nw = nw.from_native(series_pl, series_only=True)
dtype_nw = series_nw.dtype
assert isinstance(dtype_nw, nw.Enum)
result = roundtrip(dtype_nw)
assert isinstance(result, nw.Enum)
assert result == dtype_nw
assert result == series_nw.dtype
assert dtype_nw == roundtrip(result)
assert (
type(result)(dtype_pl.categories).categories
== roundtrip(result).categories
== categories
== result.categories
== roundtrip(dtype_nw).categories
)


def test_serde_non_nested_dtypes(
non_nested_type: type[NonNestedDType], roundtrip: Identity
) -> None:
dtype = non_nested_type()
result = roundtrip(dtype)
assert isinstance(result, DType)
assert isinstance(result, non_nested_type)
assert result == non_nested_type()
assert result == non_nested_type


def test_serde_nested_dtypes(
nested_dtype: NestedOrEnumDType, roundtrip: Identity
) -> None:
result = roundtrip(nested_dtype)
assert isinstance(result, DType)
assert isinstance(result, nested_dtype.__class__)
assert result == nested_dtype
assert result == nested_dtype.base_type()
3 changes: 3 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]:
ConstructorLazy: TypeAlias = Callable[[Any], "NativeLazyFrame"]
ConstructorPandasLike: TypeAlias = Callable[[Any], "pd.DataFrame"]

NestedOrEnumDType: TypeAlias = "nw.List | nw.Array | nw.Struct | nw.Enum"
"""`DType`s which **cannot** be used as bare types."""

ID_PANDAS_LIKE = frozenset(
("pandas", "pandas[nullable]", "pandas[pyarrow]", "modin", "modin[pyarrow]", "cudf")
)
Expand Down
Loading