diff --git a/tests/conftest.py b/tests/conftest.py index c823432c5c..3f50c9717d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import pytest +import narwhals as nw from narwhals._utils import Implementation, generate_temporary_column_name from tests.utils import ID_PANDAS_LIKE, PANDAS_VERSION, pyspark_session, sqlframe_session @@ -26,11 +27,17 @@ from narwhals._spark_like.dataframe import SQLFrameDataFrame from narwhals._typing import EagerAllowed - from narwhals.typing import NativeDataFrame, NativeLazyFrame - from tests.utils import Constructor, ConstructorEager, ConstructorLazy + from narwhals.typing import NativeDataFrame, NativeLazyFrame, NonNestedDType + from tests.utils import ( + Constructor, + ConstructorEager, + ConstructorLazy, + NestedOrEnumDType, + ) Data: TypeAlias = "dict[str, list[Any]]" + MIN_PANDAS_NULLABLE_VERSION = (2,) # When testing cudf.pandas in Kaggle, we get an error if we try to run @@ -321,3 +328,50 @@ def eager_backend(request: pytest.FixtureRequest) -> EagerAllowed: def eager_implementation(request: pytest.FixtureRequest) -> EagerAllowed: """Use if a test is heavily parametric, skips `str` backend.""" return request.param # type: ignore[no-any-return] + + +@pytest.fixture( + params=[ + nw.Boolean, + nw.Categorical, + nw.Date, + nw.Datetime, + nw.Decimal, + nw.Duration, + nw.Float32, + nw.Float64, + nw.Int8, + nw.Int16, + nw.Int32, + nw.Int64, + nw.Int128, + nw.Object, + nw.String, + nw.Time, + nw.UInt8, + nw.UInt16, + nw.UInt32, + nw.UInt64, + nw.UInt128, + nw.Unknown, + nw.Binary, + ], + ids=lambda tp: tp.__name__, +) +def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]: + tp_dtype: type[NonNestedDType] = request.param + return tp_dtype + + +@pytest.fixture( + params=[ + nw.List(nw.Float32), + nw.Array(nw.String, 2), + nw.Struct({"a": nw.Boolean}), + nw.Enum(["beluga", "narwhal"]), + ], + ids=lambda obj: type(obj).__name__, +) +def nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType: + dtype: NestedOrEnumDType = request.param + return dtype diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index fe2dec98e7..801cd6453f 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -17,58 +17,8 @@ if TYPE_CHECKING: from collections.abc import Iterable - from typing_extensions import TypeAlias - from narwhals.typing import IntoFrame, IntoSeries, NonNestedDType - from tests.utils import Constructor, ConstructorPandasLike - -NestedOrEnumDType: TypeAlias = "nw.List | nw.Array | nw.Struct | nw.Enum" -"""`DType`s which **cannot** be used as bare types.""" - - -@pytest.fixture( - params=[ - nw.Boolean, - nw.Categorical, - nw.Date, - nw.Datetime, - nw.Decimal, - nw.Duration, - nw.Float32, - nw.Float64, - nw.Int8, - nw.Int16, - nw.Int32, - nw.Int64, - nw.Int128, - nw.Object, - nw.String, - nw.Time, - nw.UInt8, - nw.UInt16, - nw.UInt32, - nw.UInt64, - nw.UInt128, - nw.Unknown, - nw.Binary, - ], - ids=lambda tp: tp.__name__, -) -def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]: - return request.param # type: ignore[no-any-return] - - -@pytest.fixture( - params=[ - nw.List(nw.Float32), - nw.Array(nw.String, 2), - nw.Struct({"a": nw.Boolean}), - nw.Enum(["beluga", "narwhal"]), - ], - ids=lambda obj: type(obj).__name__, -) -def nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType: - return request.param # type: ignore[no-any-return] + from tests.utils import Constructor, ConstructorPandasLike, NestedOrEnumDType @pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) diff --git a/tests/serde_test.py b/tests/serde_test.py new file mode 100644 index 0000000000..dae845364f --- /dev/null +++ b/tests/serde_test.py @@ -0,0 +1,154 @@ +"""Serialization tests, based on [py-polars/tests/unit/test_serde.py]. + +See also [Pickling Class Instances](https://docs.python.org/3/library/pickle.html#pickling-class-instances). + +[py-polars/tests/unit/test_serde.py]: https://github.com/pola-rs/polars/blob/a143eb0d7077ee9da2ce209a19c21d7f82228081/py-polars/tests/unit/test_serde.py +""" + +from __future__ import annotations + +import pickle +import string + +# ruff: noqa: S301 +from typing import TYPE_CHECKING, Protocol, TypeVar + +import pytest + +import narwhals as nw +import narwhals.stable.v1 as nw_v1 +from narwhals.dtypes import DType +from narwhals.typing import IntoDType, NonNestedDType, TimeUnit + +if TYPE_CHECKING: + from narwhals.typing import DTypes + from tests.utils import NestedOrEnumDType + + +IntoDTypeT = TypeVar("IntoDTypeT", bound=IntoDType) + + +namespaces = pytest.mark.parametrize("namespace", [nw, nw_v1]) +time_units = pytest.mark.parametrize("time_unit", ["ns", "us", "ms", "s"]) + + +class Identity(Protocol): + def __call__(self, obj: IntoDTypeT, /) -> IntoDTypeT: ... + + +def _roundtrip_pickle(protocol: int | None = None) -> Identity: + def fn(obj: IntoDTypeT, /) -> IntoDTypeT: + result: IntoDTypeT = pickle.loads(pickle.dumps(obj, protocol)) + return result + + return fn + + +@pytest.fixture( + params=[_roundtrip_pickle(), _roundtrip_pickle(4), _roundtrip_pickle(5)], + ids=["pickle-None", "pickle-4", "pickle-5"], +) +def roundtrip(request: pytest.FixtureRequest) -> Identity: + fn: Identity = request.param + return fn + + +@namespaces +@time_units +def test_serde_datetime_dtype( + namespace: DTypes, time_unit: TimeUnit, roundtrip: Identity +) -> None: + dtype = namespace.Datetime(time_unit) + result = roundtrip(dtype) + assert result == dtype + + +@namespaces +@time_units +def test_serde_duration_dtype( + namespace: DTypes, time_unit: TimeUnit, roundtrip: Identity +) -> None: + dtype = namespace.Duration(time_unit) + result = roundtrip(dtype) + assert result == dtype + + +def test_serde_doubly_nested_struct_dtype(roundtrip: Identity) -> None: + dtype = nw.Struct([nw.Field("a", nw.List(nw.String))]) + result = roundtrip(dtype) + assert result == dtype + + +def test_serde_doubly_nested_array_dtype(roundtrip: Identity) -> None: + dtype = nw.Array(nw.Array(nw.Int32(), 2), 3) + result = roundtrip(dtype) + assert result == dtype + + +def test_serde_dtype_class(roundtrip: Identity) -> None: + dtype_class = nw.Datetime + result = roundtrip(dtype_class) + assert result == dtype_class + assert isinstance(result, type) + + +def test_serde_enum_dtype(roundtrip: Identity) -> None: + dtype = nw.Enum(["a", "b"]) + result = roundtrip(dtype) + assert result == dtype + assert isinstance(result, DType) + + +def test_serde_enum_v1_dtype(roundtrip: Identity) -> None: + dtype = nw_v1.Enum() + result = roundtrip(dtype) + assert result == dtype + assert isinstance(result, nw_v1.Enum) + tp = type(result) + with pytest.raises(TypeError): + tp(["a", "b"]) # type: ignore[call-arg] + + +def test_serde_enum_deferred(roundtrip: Identity) -> None: + pytest.importorskip("polars") + import polars as pl + + categories = tuple(string.printable) + dtype_pl = pl.Enum(categories) + series_pl = pl.Series(categories).cast(dtype_pl).extend_constant(categories[5], 1000) + series_nw = nw.from_native(series_pl, series_only=True) + dtype_nw = series_nw.dtype + assert isinstance(dtype_nw, nw.Enum) + result = roundtrip(dtype_nw) + assert isinstance(result, nw.Enum) + assert result == dtype_nw + assert result == series_nw.dtype + assert dtype_nw == roundtrip(result) + assert ( + type(result)(dtype_pl.categories).categories + == roundtrip(result).categories + == categories + == result.categories + == roundtrip(dtype_nw).categories + ) + + +def test_serde_non_nested_dtypes( + non_nested_type: type[NonNestedDType], roundtrip: Identity +) -> None: + dtype = non_nested_type() + result = roundtrip(dtype) + assert isinstance(result, DType) + assert isinstance(result, non_nested_type) + assert result == non_nested_type() + assert result == non_nested_type + + +def test_serde_nested_dtypes( + nested_dtype: NestedOrEnumDType, roundtrip: Identity +) -> None: + result = roundtrip(nested_dtype) + assert isinstance(result, DType) + assert isinstance(result, nested_dtype.__class__) + assert result == nested_dtype + assert result == nested_dtype.base_type() diff --git a/tests/utils.py b/tests/utils.py index 4e06c35063..4fc11492a1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -48,6 +48,9 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]: ConstructorLazy: TypeAlias = Callable[[Any], "NativeLazyFrame"] ConstructorPandasLike: TypeAlias = Callable[[Any], "pd.DataFrame"] +NestedOrEnumDType: TypeAlias = "nw.List | nw.Array | nw.Struct | nw.Enum" +"""`DType`s which **cannot** be used as bare types.""" + ID_PANDAS_LIKE = frozenset( ("pandas", "pandas[nullable]", "pandas[pyarrow]", "modin", "modin[pyarrow]", "cudf") )