diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index e73e6ee93a..8ef5d30183 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -17,6 +17,7 @@ from collections.abc import Iterator, Sequence from typing import Any + import _typeshed from typing_extensions import Self, TypeIs from narwhals.typing import IntoDType, TimeUnit @@ -33,14 +34,12 @@ def _validate_dtype(dtype: DType | type[DType]) -> None: def _is_into_dtype(obj: Any) -> TypeIs[IntoDType]: return isinstance(obj, DType) or ( - isinstance(obj, type) - and issubclass(obj, DType) - and not issubclass(obj, NestedType) + isinstance(obj, DTypeClass) and not issubclass(obj, NestedType) ) def _is_nested_type(obj: Any) -> TypeIs[type[NestedType]]: - return isinstance(obj, type) and issubclass(obj, NestedType) + return isinstance(obj, DTypeClass) and issubclass(obj, NestedType) def _validate_into_dtype(dtype: Any) -> None: @@ -59,10 +58,40 @@ def _validate_into_dtype(dtype: Any) -> None: raise TypeError(msg) -class DType: - __slots__ = () +class DTypeClass(type): + """Metaclass for DType classes. - def __repr__(self) -> str: # pragma: no cover + - Nicely print classes. + - Ensure [`__slots__`] are always defined to prevent `__dict__` creation (empty by default). + + [`__slots__`]: https://docs.python.org/3/reference/datamodel.html#object.__slots__ + """ + + def __repr__(cls) -> str: + return cls.__name__ + + # https://github.com/python/typeshed/blob/776508741d76b58f9dcb2aaf42f7d4596a48d580/stdlib/abc.pyi#L13-L19 + # https://github.com/python/typeshed/blob/776508741d76b58f9dcb2aaf42f7d4596a48d580/stdlib/_typeshed/__init__.pyi#L36-L40 + # https://github.com/astral-sh/ruff/issues/8353#issuecomment-1786238311 + # https://docs.python.org/3/reference/datamodel.html#creating-the-class-object + def __new__( + metacls: type[_typeshed.Self], + cls_name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + /, + **kwds: Any, + ) -> _typeshed.Self: + namespace.setdefault("__slots__", ()) + return super().__new__(metacls, cls_name, bases, namespace, **kwds) # type: ignore[no-any-return, misc] + + +class DType(metaclass=DTypeClass): + """Base class for all Narwhals data types.""" + + __slots__ = () # NOTE: Keep this one defined manually for the type checker + + def __repr__(self) -> str: return self.__class__.__qualname__ @classmethod @@ -72,13 +101,11 @@ def base_type(cls) -> type[Self]: Examples: >>> import narwhals as nw >>> nw.Datetime("us").base_type() - - + Datetime >>> nw.String.base_type() - - + String >>> nw.List(nw.Int64).base_type() - + List """ return cls @@ -143,8 +170,6 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] >>> nw.Date() == nw.Datetime False """ - from narwhals._utils import isinstance_or_issubclass - return isinstance_or_issubclass(other, type(self)) def __hash__(self) -> int: @@ -154,44 +179,30 @@ def __hash__(self) -> int: class NumericType(DType): """Base class for numeric data types.""" - __slots__ = () - class IntegerType(NumericType): """Base class for integer data types.""" - __slots__ = () - class SignedIntegerType(IntegerType): """Base class for signed integer data types.""" - __slots__ = () - class UnsignedIntegerType(IntegerType): """Base class for unsigned integer data types.""" - __slots__ = () - class FloatType(NumericType): """Base class for float data types.""" - __slots__ = () - class TemporalType(DType): """Base class for temporal data types.""" - __slots__ = () - class NestedType(DType): """Base class for nested data types.""" - __slots__ = () - class Decimal(NumericType): """Decimal type. @@ -204,8 +215,6 @@ class Decimal(NumericType): Decimal """ - __slots__ = () - class Int128(SignedIntegerType): """128-bit signed integer type. @@ -226,8 +235,6 @@ class Int128(SignedIntegerType): Int128 """ - __slots__ = () - class Int64(SignedIntegerType): """64-bit signed integer type. @@ -241,8 +248,6 @@ class Int64(SignedIntegerType): Int64 """ - __slots__ = () - class Int32(SignedIntegerType): """32-bit signed integer type. @@ -256,8 +261,6 @@ class Int32(SignedIntegerType): Int32 """ - __slots__ = () - class Int16(SignedIntegerType): """16-bit signed integer type. @@ -271,8 +274,6 @@ class Int16(SignedIntegerType): Int16 """ - __slots__ = () - class Int8(SignedIntegerType): """8-bit signed integer type. @@ -286,8 +287,6 @@ class Int8(SignedIntegerType): Int8 """ - __slots__ = () - class UInt128(UnsignedIntegerType): """128-bit unsigned integer type. @@ -302,8 +301,6 @@ class UInt128(UnsignedIntegerType): UInt128 """ - __slots__ = () - class UInt64(UnsignedIntegerType): """64-bit unsigned integer type. @@ -317,8 +314,6 @@ class UInt64(UnsignedIntegerType): UInt64 """ - __slots__ = () - class UInt32(UnsignedIntegerType): """32-bit unsigned integer type. @@ -332,8 +327,6 @@ class UInt32(UnsignedIntegerType): UInt32 """ - __slots__ = () - class UInt16(UnsignedIntegerType): """16-bit unsigned integer type. @@ -347,8 +340,6 @@ class UInt16(UnsignedIntegerType): UInt16 """ - __slots__ = () - class UInt8(UnsignedIntegerType): """8-bit unsigned integer type. @@ -362,8 +353,6 @@ class UInt8(UnsignedIntegerType): UInt8 """ - __slots__ = () - class Float64(FloatType): """64-bit floating point type. @@ -377,8 +366,6 @@ class Float64(FloatType): Float64 """ - __slots__ = () - class Float32(FloatType): """32-bit floating point type. @@ -392,8 +379,6 @@ class Float32(FloatType): Float32 """ - __slots__ = () - class String(DType): """UTF-8 encoded string type. @@ -406,8 +391,6 @@ class String(DType): String """ - __slots__ = () - class Boolean(DType): """Boolean type. @@ -420,8 +403,6 @@ class Boolean(DType): Boolean """ - __slots__ = () - class Object(DType): """Data type for wrapping arbitrary Python objects. @@ -435,8 +416,6 @@ class Object(DType): Object """ - __slots__ = () - class Unknown(DType): """Type representing DataType values that could not be determined statically. @@ -449,10 +428,8 @@ class Unknown(DType): Unknown """ - __slots__ = () - -class _DatetimeMeta(type): +class _DatetimeMeta(DTypeClass): @property def time_unit(cls) -> TimeUnit: """Unit of time. Defaults to `'us'` (microseconds).""" @@ -546,7 +523,7 @@ def __repr__(self) -> str: # pragma: no cover return f"{class_name}(time_unit={self.time_unit!r}, time_zone={self.time_zone!r})" -class _DurationMeta(type): +class _DurationMeta(DTypeClass): @property def time_unit(cls) -> TimeUnit: """Unit of time. Defaults to `'us'` (microseconds).""" @@ -627,8 +604,6 @@ class Categorical(DType): Categorical """ - __slots__ = () - class Enum(DType): """A fixed categorical encoding of a unique set of strings. @@ -686,7 +661,7 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] >>> nw.Enum(["a", "b", "c"]) == nw.Enum True """ - if type(other) is type: + if type(other) is DTypeClass: return other is Enum return isinstance(other, type(self)) and self.categories == other.categories @@ -801,7 +776,7 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] >>> nw.Struct({"a": nw.Int64}) == nw.Struct True """ - if type(other) is type and issubclass(other, self.__class__): + if type(other) is DTypeClass and issubclass(other, self.__class__): return True if isinstance(other, self.__class__): return self.fields == other.fields @@ -864,7 +839,7 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] >>> nw.List(nw.Int64) == nw.List True """ - if type(other) is type and issubclass(other, self.__class__): + if type(other) is DTypeClass and issubclass(other, self.__class__): return True if isinstance(other, self.__class__): return self.inner == other.inner @@ -937,7 +912,7 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] >>> nw.Array(nw.Int64, 2) == nw.Array True """ - if type(other) is type and issubclass(other, self.__class__): + if type(other) is DTypeClass and issubclass(other, self.__class__): return True if isinstance(other, self.__class__): if self.shape != other.shape: @@ -972,8 +947,6 @@ class Date(TemporalType): Date """ - __slots__ = () - class Time(TemporalType): """Data type representing the time of day. @@ -999,8 +972,6 @@ class Time(TemporalType): Time """ - __slots__ = () - class Binary(DType): """Binary type. @@ -1024,5 +995,3 @@ class Binary(DType): >>> nw.from_native(rel).collect_schema()["t"] Binary """ - - __slots__ = () diff --git a/narwhals/stable/v1/_dtypes.py b/narwhals/stable/v1/_dtypes.py index 98b490b56f..5b4ea54958 100644 --- a/narwhals/stable/v1/_dtypes.py +++ b/narwhals/stable/v1/_dtypes.py @@ -12,6 +12,7 @@ Datetime as NwDatetime, Decimal, DType, + DTypeClass, Duration as NwDuration, Enum as NwEnum, Field, @@ -48,8 +49,6 @@ class Datetime(NwDatetime): - __slots__ = NwDatetime.__slots__ - @inherit_doc(NwDatetime) def __init__( self, time_unit: TimeUnit = "us", time_zone: str | timezone | None = None @@ -61,8 +60,6 @@ def __hash__(self) -> int: class Duration(NwDuration): - __slots__ = NwDuration.__slots__ - @inherit_doc(NwDuration) def __init__(self, time_unit: TimeUnit = "us") -> None: super().__init__(time_unit) @@ -85,13 +82,11 @@ class Enum(NwEnum): Enum """ - __slots__ = () - def __init__(self) -> None: super(NwEnum, self).__init__() def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] - if type(other) is type: + if type(other) is DTypeClass: return other in {type(self), NwEnum} return isinstance(other, type(self)) diff --git a/tests/dtypes/dtypes_test.py b/tests/dtypes/dtypes_test.py index a435d033f9..bed3adc396 100644 --- a/tests/dtypes/dtypes_test.py +++ b/tests/dtypes/dtypes_test.py @@ -65,7 +65,7 @@ def test_list_valid() -> None: assert dtype == nw.List assert dtype != nw.List(nw.Float32) assert dtype != nw.Duration - assert repr(dtype) == "List()" + assert repr(dtype) == "List(Int64)" dtype = nw.List(nw.List(nw.Int64)) assert dtype == nw.List(nw.List(nw.Int64)) assert dtype == nw.List @@ -80,7 +80,7 @@ def test_array_valid() -> None: assert dtype != nw.Array(nw.Int64, 3) assert dtype != nw.Array(nw.Float32, 2) assert dtype != nw.Duration - assert repr(dtype) == "Array(, shape=(2,))" + assert repr(dtype) == "Array(Int64, shape=(2,))" dtype = nw.Array(nw.Array(nw.Int64, 2), 2) assert dtype == nw.Array(nw.Array(nw.Int64, 2), 2) assert dtype == nw.Array @@ -100,7 +100,7 @@ def test_struct_valid() -> None: assert dtype == nw.Struct assert dtype != nw.Struct([nw.Field("a", nw.Float32)]) assert dtype != nw.Duration - assert repr(dtype) == "Struct({'a': })" + assert repr(dtype) == "Struct({'a': Int64})" dtype = nw.Struct({"a": nw.Int64, "b": nw.String}) assert dtype == nw.Struct({"a": nw.Int64, "b": nw.String}) @@ -119,7 +119,7 @@ def test_struct_reverse() -> None: def test_field_repr() -> None: dtype = nw.Field("a", nw.Int32) - assert repr(dtype) == "Field('a', )" + assert repr(dtype) == "Field('a', Int32)" def test_field_eq() -> None: @@ -500,6 +500,20 @@ def test_enum_hash() -> None: assert nw.Enum(["a", "b"]) not in {nw.Enum(["a", "b", "c"])} +@pytest.mark.xfail( + reason="https://github.com/narwhals-dev/narwhals/pull/3213#discussion_r2437271987" +) +@pytest.mark.parametrize("dtype_name", ["Datetime", "Duration", "Enum"]) +def test_dtype_repr_versioned(dtype_name: str) -> None: + from narwhals.stable import v1 as nw_v1 + + dtype_class_main = getattr(nw, dtype_name) + dtype_class_v1 = getattr(nw_v1, dtype_name) + + assert dtype_class_main is not dtype_class_v1 + assert repr(dtype_class_main) != repr(dtype_class_v1) + + def test_datetime_w_tz_duckdb() -> None: pytest.importorskip("duckdb") import duckdb diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 9121c3dc4c..7b97fcb1c6 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -18,7 +18,10 @@ ) if TYPE_CHECKING: + from collections.abc import Mapping + from narwhals._native import NativeSQLFrame + from narwhals.typing import NonNestedDType DATA = { "a": [1], @@ -38,7 +41,7 @@ "o": ["a"], "p": [1], } -SCHEMA = { +SCHEMA: Mapping[str, type[NonNestedDType]] = { "a": nw.Int64, "b": nw.Int32, "c": nw.Int16, @@ -86,7 +89,7 @@ def test_cast(constructor: Constructor) -> None: nw.col(col_).cast(dtype) for col_, dtype in schema.items() ) - cast_map = { + cast_map: Mapping[str, type[NonNestedDType]] = { "a": nw.Int32, "b": nw.Int16, "c": nw.Int8, @@ -134,7 +137,7 @@ def test_cast_series( .lazy() .collect() ) - cast_map = { + cast_map: Mapping[str, type[NonNestedDType]] = { "a": nw.Int32, "b": nw.Int16, "c": nw.Int8,