Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 45 additions & 76 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections.abc import Iterator, Sequence
from typing import Any

import _typeshed
from typing_extensions import Self, TypeIs

from narwhals.typing import IntoDType, TimeUnit
Expand All @@ -33,14 +34,12 @@ def _validate_dtype(dtype: DType | type[DType]) -> None:

def _is_into_dtype(obj: Any) -> TypeIs[IntoDType]:
return isinstance(obj, DType) or (
isinstance(obj, type)
and issubclass(obj, DType)
and not issubclass(obj, NestedType)
isinstance(obj, DTypeClass) and not issubclass(obj, NestedType)
)


def _is_nested_type(obj: Any) -> TypeIs[type[NestedType]]:
return isinstance(obj, type) and issubclass(obj, NestedType)
return isinstance(obj, DTypeClass) and issubclass(obj, NestedType)


def _validate_into_dtype(dtype: Any) -> None:
Expand All @@ -59,10 +58,40 @@ def _validate_into_dtype(dtype: Any) -> None:
raise TypeError(msg)


class DType:
__slots__ = ()
class DTypeClass(type):
"""Metaclass for DType classes.

def __repr__(self) -> str: # pragma: no cover
- Nicely print classes.
- Ensure [`__slots__`] are always defined to prevent `__dict__` creation (empty by default).

[`__slots__`]: https://docs.python.org/3/reference/datamodel.html#object.__slots__
"""

def __repr__(cls) -> str:
return cls.__name__

# https://github.com/python/typeshed/blob/776508741d76b58f9dcb2aaf42f7d4596a48d580/stdlib/abc.pyi#L13-L19
# https://github.com/python/typeshed/blob/776508741d76b58f9dcb2aaf42f7d4596a48d580/stdlib/_typeshed/__init__.pyi#L36-L40
# https://github.com/astral-sh/ruff/issues/8353#issuecomment-1786238311
# https://docs.python.org/3/reference/datamodel.html#creating-the-class-object
def __new__(
metacls: type[_typeshed.Self],
cls_name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
/,
**kwds: Any,
) -> _typeshed.Self:
namespace.setdefault("__slots__", ())
return super().__new__(metacls, cls_name, bases, namespace, **kwds) # type: ignore[no-any-return, misc]


class DType(metaclass=DTypeClass):
"""Base class for all Narwhals data types."""

__slots__ = () # NOTE: Keep this one defined manually for the type checker

def __repr__(self) -> str:
return self.__class__.__qualname__

@classmethod
Expand All @@ -72,13 +101,11 @@ def base_type(cls) -> type[Self]:
Examples:
>>> import narwhals as nw
>>> nw.Datetime("us").base_type()
<class 'narwhals.dtypes.Datetime'>

Datetime
>>> nw.String.base_type()
<class 'narwhals.dtypes.String'>

String
>>> nw.List(nw.Int64).base_type()
<class 'narwhals.dtypes.List'>
List
"""
return cls

Expand Down Expand Up @@ -143,8 +170,6 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
>>> nw.Date() == nw.Datetime
False
"""
from narwhals._utils import isinstance_or_issubclass

return isinstance_or_issubclass(other, type(self))

def __hash__(self) -> int:
Expand All @@ -154,44 +179,30 @@ def __hash__(self) -> int:
class NumericType(DType):
"""Base class for numeric data types."""

__slots__ = ()


class IntegerType(NumericType):
"""Base class for integer data types."""

__slots__ = ()


class SignedIntegerType(IntegerType):
"""Base class for signed integer data types."""

__slots__ = ()


class UnsignedIntegerType(IntegerType):
"""Base class for unsigned integer data types."""

__slots__ = ()


class FloatType(NumericType):
"""Base class for float data types."""

__slots__ = ()


class TemporalType(DType):
"""Base class for temporal data types."""

__slots__ = ()


class NestedType(DType):
"""Base class for nested data types."""

__slots__ = ()


class Decimal(NumericType):
"""Decimal type.
Expand All @@ -204,8 +215,6 @@ class Decimal(NumericType):
Decimal
"""

__slots__ = ()


class Int128(SignedIntegerType):
"""128-bit signed integer type.
Expand All @@ -226,8 +235,6 @@ class Int128(SignedIntegerType):
Int128
"""

__slots__ = ()


class Int64(SignedIntegerType):
"""64-bit signed integer type.
Expand All @@ -241,8 +248,6 @@ class Int64(SignedIntegerType):
Int64
"""

__slots__ = ()


class Int32(SignedIntegerType):
"""32-bit signed integer type.
Expand All @@ -256,8 +261,6 @@ class Int32(SignedIntegerType):
Int32
"""

__slots__ = ()


class Int16(SignedIntegerType):
"""16-bit signed integer type.
Expand All @@ -271,8 +274,6 @@ class Int16(SignedIntegerType):
Int16
"""

__slots__ = ()


class Int8(SignedIntegerType):
"""8-bit signed integer type.
Expand All @@ -286,8 +287,6 @@ class Int8(SignedIntegerType):
Int8
"""

__slots__ = ()


class UInt128(UnsignedIntegerType):
"""128-bit unsigned integer type.
Expand All @@ -302,8 +301,6 @@ class UInt128(UnsignedIntegerType):
UInt128
"""

__slots__ = ()


class UInt64(UnsignedIntegerType):
"""64-bit unsigned integer type.
Expand All @@ -317,8 +314,6 @@ class UInt64(UnsignedIntegerType):
UInt64
"""

__slots__ = ()


class UInt32(UnsignedIntegerType):
"""32-bit unsigned integer type.
Expand All @@ -332,8 +327,6 @@ class UInt32(UnsignedIntegerType):
UInt32
"""

__slots__ = ()


class UInt16(UnsignedIntegerType):
"""16-bit unsigned integer type.
Expand All @@ -347,8 +340,6 @@ class UInt16(UnsignedIntegerType):
UInt16
"""

__slots__ = ()


class UInt8(UnsignedIntegerType):
"""8-bit unsigned integer type.
Expand All @@ -362,8 +353,6 @@ class UInt8(UnsignedIntegerType):
UInt8
"""

__slots__ = ()


class Float64(FloatType):
"""64-bit floating point type.
Expand All @@ -377,8 +366,6 @@ class Float64(FloatType):
Float64
"""

__slots__ = ()


class Float32(FloatType):
"""32-bit floating point type.
Expand All @@ -392,8 +379,6 @@ class Float32(FloatType):
Float32
"""

__slots__ = ()


class String(DType):
"""UTF-8 encoded string type.
Expand All @@ -406,8 +391,6 @@ class String(DType):
String
"""

__slots__ = ()


class Boolean(DType):
"""Boolean type.
Expand All @@ -420,8 +403,6 @@ class Boolean(DType):
Boolean
"""

__slots__ = ()


class Object(DType):
"""Data type for wrapping arbitrary Python objects.
Expand All @@ -435,8 +416,6 @@ class Object(DType):
Object
"""

__slots__ = ()


class Unknown(DType):
"""Type representing DataType values that could not be determined statically.
Expand All @@ -449,10 +428,8 @@ class Unknown(DType):
Unknown
"""

__slots__ = ()


class _DatetimeMeta(type):
class _DatetimeMeta(DTypeClass):
@property
def time_unit(cls) -> TimeUnit:
"""Unit of time. Defaults to `'us'` (microseconds)."""
Expand Down Expand Up @@ -546,7 +523,7 @@ def __repr__(self) -> str: # pragma: no cover
return f"{class_name}(time_unit={self.time_unit!r}, time_zone={self.time_zone!r})"


class _DurationMeta(type):
class _DurationMeta(DTypeClass):
@property
def time_unit(cls) -> TimeUnit:
"""Unit of time. Defaults to `'us'` (microseconds)."""
Expand Down Expand Up @@ -627,8 +604,6 @@ class Categorical(DType):
Categorical
"""

__slots__ = ()


class Enum(DType):
"""A fixed categorical encoding of a unique set of strings.
Expand Down Expand Up @@ -686,7 +661,7 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
>>> nw.Enum(["a", "b", "c"]) == nw.Enum
True
"""
if type(other) is type:
if type(other) is DTypeClass:
return other is Enum
return isinstance(other, type(self)) and self.categories == other.categories

Expand Down Expand Up @@ -801,7 +776,7 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
>>> nw.Struct({"a": nw.Int64}) == nw.Struct
True
"""
if type(other) is type and issubclass(other, self.__class__):
if type(other) is DTypeClass and issubclass(other, self.__class__):
return True
if isinstance(other, self.__class__):
return self.fields == other.fields
Expand Down Expand Up @@ -864,7 +839,7 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
>>> nw.List(nw.Int64) == nw.List
True
"""
if type(other) is type and issubclass(other, self.__class__):
if type(other) is DTypeClass and issubclass(other, self.__class__):
return True
if isinstance(other, self.__class__):
return self.inner == other.inner
Expand Down Expand Up @@ -937,7 +912,7 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
>>> nw.Array(nw.Int64, 2) == nw.Array
True
"""
if type(other) is type and issubclass(other, self.__class__):
if type(other) is DTypeClass and issubclass(other, self.__class__):
return True
if isinstance(other, self.__class__):
if self.shape != other.shape:
Expand Down Expand Up @@ -972,8 +947,6 @@ class Date(TemporalType):
Date
"""

__slots__ = ()


class Time(TemporalType):
"""Data type representing the time of day.
Expand All @@ -999,8 +972,6 @@ class Time(TemporalType):
Time
"""

__slots__ = ()


class Binary(DType):
"""Binary type.
Expand All @@ -1024,5 +995,3 @@ class Binary(DType):
>>> nw.from_native(rel).collect_schema()["t"]
Binary
"""

__slots__ = ()
Loading
Loading