From 819c56455594ff92b716bb7822461b08f94a0c1d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 10 Oct 2025 17:41:16 +0000 Subject: [PATCH 01/12] perf: Add `__slots__` to all `DType`s Closes #3185 --- narwhals/dtypes.py | 12 ++++++++++++ tests/expr_and_series/cast_test.py | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 831f704be8..7ea4a9e690 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -60,6 +60,8 @@ def _validate_into_dtype(dtype: Any) -> None: class DType: + __slots__ = () + def __repr__(self) -> str: # pragma: no cover return self.__class__.__qualname__ @@ -438,6 +440,8 @@ class Datetime(TemporalType, metaclass=_DatetimeMeta): Datetime(time_unit='ms', time_zone='Africa/Accra') """ + __slots__ = ("time_unit", "time_zone") + def __init__( self, time_unit: TimeUnit = "us", time_zone: str | timezone | None = None ) -> None: @@ -521,6 +525,8 @@ class Duration(TemporalType, metaclass=_DurationMeta): Duration(time_unit='ms') """ + __slots__ = ("time_unit",) + def __init__(self, time_unit: TimeUnit = "us") -> None: if time_unit not in {"s", "ms", "us", "ns"}: msg = ( @@ -586,6 +592,8 @@ class Enum(DType): Enum(categories=['beluga', 'narwhal', 'orca']) """ + __slots__ = ("_cached_categories", "_delayed_categories") + def __init__(self, categories: Iterable[str] | type[enum.Enum]) -> None: self._delayed_categories: _DeferredIterable[str] | None = None self._cached_categories: tuple[str, ...] | None = None @@ -655,6 +663,7 @@ class Field: [Field('a', Int64), Field('b', List(String))] """ + __slots__ = ("dtype", "name") name: str """The name of the field within its parent `Struct`.""" dtype: IntoDType @@ -713,6 +722,7 @@ class Struct(NestedType): Struct({'a': Int64, 'b': List(String)}) """ + __slots__ = ("fields",) fields: list[Field] """The fields that make up the struct.""" @@ -782,6 +792,7 @@ class List(NestedType): List(String) """ + __slots__ = ("inner",) inner: IntoDType """The DType of the values within each list.""" @@ -832,6 +843,7 @@ class Array(NestedType): Array(Int32, shape=(2,)) """ + __slots__ = ("inner", "shape", "size") inner: IntoDType """The DType of the values within each array.""" size: int diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 03745758a7..549fd76b3a 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -371,7 +371,7 @@ def test_cast_typing_invalid() -> None: # feel free to update the types used # See (https://github.com/narwhals-dev/narwhals/pull/2654#discussion_r2142263770) - with pytest.raises(AttributeError): + with pytest.raises(TypeError): df.select(a.cast(nw.Struct)) # type: ignore[arg-type] with pytest.raises(AttributeError): @@ -389,7 +389,7 @@ def test_cast_typing_invalid() -> None: with pytest.raises((ValueError, AttributeError)): df.select(a.cast(nw.Struct({"a": nw.Int16, "b": nw.Enum}))) # type: ignore[dict-item] - with pytest.raises(AttributeError): + with pytest.raises(TypeError): df.select(a.cast(nw.List(nw.Struct))) # type: ignore[arg-type] with pytest.raises(AttributeError): From 2fe3d8c18c42ecdb92cb906fc71d97bf342eb25d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 10 Oct 2025 18:04:47 +0000 Subject: [PATCH 02/12] fix: Ensure `__dict__` isn't created https://docs.python.org/3/reference/datamodel.html#slots --- narwhals/dtypes.py | 56 +++++++++++++++++++++++++ tests/dtypes_test.py | 99 ++++++++++++++++++++++++++++---------------- 2 files changed, 120 insertions(+), 35 deletions(-) diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 7ea4a9e690..e73e6ee93a 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -154,30 +154,44 @@ def __hash__(self) -> int: class NumericType(DType): """Base class for numeric data types.""" + __slots__ = () + class IntegerType(NumericType): """Base class for integer data types.""" + __slots__ = () + class SignedIntegerType(IntegerType): """Base class for signed integer data types.""" + __slots__ = () + class UnsignedIntegerType(IntegerType): """Base class for unsigned integer data types.""" + __slots__ = () + class FloatType(NumericType): """Base class for float data types.""" + __slots__ = () + class TemporalType(DType): """Base class for temporal data types.""" + __slots__ = () + class NestedType(DType): """Base class for nested data types.""" + __slots__ = () + class Decimal(NumericType): """Decimal type. @@ -190,6 +204,8 @@ class Decimal(NumericType): Decimal """ + __slots__ = () + class Int128(SignedIntegerType): """128-bit signed integer type. @@ -210,6 +226,8 @@ class Int128(SignedIntegerType): Int128 """ + __slots__ = () + class Int64(SignedIntegerType): """64-bit signed integer type. @@ -223,6 +241,8 @@ class Int64(SignedIntegerType): Int64 """ + __slots__ = () + class Int32(SignedIntegerType): """32-bit signed integer type. @@ -236,6 +256,8 @@ class Int32(SignedIntegerType): Int32 """ + __slots__ = () + class Int16(SignedIntegerType): """16-bit signed integer type. @@ -249,6 +271,8 @@ class Int16(SignedIntegerType): Int16 """ + __slots__ = () + class Int8(SignedIntegerType): """8-bit signed integer type. @@ -262,6 +286,8 @@ class Int8(SignedIntegerType): Int8 """ + __slots__ = () + class UInt128(UnsignedIntegerType): """128-bit unsigned integer type. @@ -276,6 +302,8 @@ class UInt128(UnsignedIntegerType): UInt128 """ + __slots__ = () + class UInt64(UnsignedIntegerType): """64-bit unsigned integer type. @@ -289,6 +317,8 @@ class UInt64(UnsignedIntegerType): UInt64 """ + __slots__ = () + class UInt32(UnsignedIntegerType): """32-bit unsigned integer type. @@ -302,6 +332,8 @@ class UInt32(UnsignedIntegerType): UInt32 """ + __slots__ = () + class UInt16(UnsignedIntegerType): """16-bit unsigned integer type. @@ -315,6 +347,8 @@ class UInt16(UnsignedIntegerType): UInt16 """ + __slots__ = () + class UInt8(UnsignedIntegerType): """8-bit unsigned integer type. @@ -328,6 +362,8 @@ class UInt8(UnsignedIntegerType): UInt8 """ + __slots__ = () + class Float64(FloatType): """64-bit floating point type. @@ -341,6 +377,8 @@ class Float64(FloatType): Float64 """ + __slots__ = () + class Float32(FloatType): """32-bit floating point type. @@ -354,6 +392,8 @@ class Float32(FloatType): Float32 """ + __slots__ = () + class String(DType): """UTF-8 encoded string type. @@ -366,6 +406,8 @@ class String(DType): String """ + __slots__ = () + class Boolean(DType): """Boolean type. @@ -378,6 +420,8 @@ class Boolean(DType): Boolean """ + __slots__ = () + class Object(DType): """Data type for wrapping arbitrary Python objects. @@ -391,6 +435,8 @@ class Object(DType): Object """ + __slots__ = () + class Unknown(DType): """Type representing DataType values that could not be determined statically. @@ -403,6 +449,8 @@ class Unknown(DType): Unknown """ + __slots__ = () + class _DatetimeMeta(type): @property @@ -579,6 +627,8 @@ class Categorical(DType): Categorical """ + __slots__ = () + class Enum(DType): """A fixed categorical encoding of a unique set of strings. @@ -922,6 +972,8 @@ class Date(TemporalType): Date """ + __slots__ = () + class Time(TemporalType): """Data type representing the time of day. @@ -947,6 +999,8 @@ class Time(TemporalType): Time """ + __slots__ = () + class Binary(DType): """Binary type. @@ -970,3 +1024,5 @@ class Binary(DType): >>> nw.from_native(rel).collect_schema()["t"] Binary """ + + __slots__ = () diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index d63384647f..ab25d89925 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -17,9 +17,57 @@ if TYPE_CHECKING: from collections.abc import Iterable + from typing_extensions import TypeAlias + from narwhals.typing import IntoFrame, IntoSeries, NonNestedDType from tests.utils import Constructor, ConstructorPandasLike +NestedOrEnumDType: TypeAlias = "nw.List | nw.Array | nw.Struct | nw.Enum" +"""`DType`s which **cannot** be used as bare types.""" + + +@pytest.fixture( + params=[ + nw.Boolean, + nw.Categorical, + nw.Date, + nw.Datetime, + nw.Decimal, + nw.Duration, + nw.Float32, + nw.Float64, + nw.Int8, + nw.Int16, + nw.Int32, + nw.Int64, + nw.Int128, + nw.Object, + nw.String, + nw.Time, + nw.UInt8, + nw.UInt16, + nw.UInt32, + nw.UInt64, + nw.UInt128, + nw.Unknown, + nw.Binary, + ] +) +def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]: + return request.param # type: ignore[no-any-return] + + +@pytest.fixture( + params=[ + nw.List(nw.Float32), + nw.Array(nw.String, 2), + nw.Struct({"a": nw.Boolean}), + nw.Enum(["beluga", "narwhal"]), + ] +) +def nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType: + return request.param # type: ignore[no-any-return] + @pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) @pytest.mark.parametrize("time_zone", ["Europe/Rome", timezone.utc, None]) @@ -534,43 +582,13 @@ def test_datetime_w_tz_pyspark() -> None: # pragma: no cover assert result["a"] == nw.List(nw.Datetime("us", "UTC")) -@pytest.mark.parametrize( - "dtype", - [ - nw.Boolean, - nw.Categorical, - nw.Date, - nw.Datetime, - nw.Decimal, - nw.Duration, - nw.Float32, - nw.Float64, - nw.Int8, - nw.Int16, - nw.Int32, - nw.Int64, - nw.Int128, - nw.Object, - nw.String, - nw.Time, - nw.UInt8, - nw.UInt16, - nw.UInt32, - nw.UInt64, - nw.UInt128, - nw.Unknown, - nw.Binary, - ], -) -def test_dtype_base_type_non_nested(dtype: type[NonNestedDType]) -> None: - assert dtype.base_type() is dtype().base_type() +def test_dtype_base_type_non_nested(non_nested_type: type[NonNestedDType]) -> None: + assert non_nested_type.base_type() is non_nested_type().base_type() -def test_dtype_base_type_nested() -> None: - assert nw.List.base_type() is nw.List(nw.Float32).base_type() - assert nw.Array.base_type() is nw.Array(nw.String, 2).base_type() - assert nw.Struct.base_type() is nw.Struct({"a": nw.Boolean}).base_type() - assert nw.Enum.base_type() is nw.Enum(["beluga", "narwhal"]).base_type() +def test_dtype_base_type_nested(nested_dtype: NestedOrEnumDType) -> None: + base = nested_dtype.base_type() + assert base.base_type() == nested_dtype.base_type() @pytest.mark.parametrize( @@ -594,3 +612,14 @@ def test_pandas_datetime_ignored_time_unit_warns( ctx = does_not_warn() if PANDAS_VERSION >= (2,) else context with ctx: df.select(expr) + + +def test_dtype___slots___non_nested(non_nested_type: type[NonNestedDType]) -> None: + dtype = non_nested_type() + with pytest.raises(AttributeError): + dtype.i_dont_exist = 100 # type: ignore[union-attr] + + +def test_dtype___slots___nested(nested_dtype: NestedOrEnumDType) -> None: + with pytest.raises(AttributeError): + nested_dtype.i_dont_exist = 999 # type: ignore[union-attr] From aa29d41a3eb54b65f6a23d60dd606caa7e75d115 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 10 Oct 2025 18:24:55 +0000 Subject: [PATCH 03/12] fix: Rinse/repeat `v1` slots --- narwhals/stable/v1/_dtypes.py | 6 ++++++ tests/v1_test.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/narwhals/stable/v1/_dtypes.py b/narwhals/stable/v1/_dtypes.py index 060980c562..98b490b56f 100644 --- a/narwhals/stable/v1/_dtypes.py +++ b/narwhals/stable/v1/_dtypes.py @@ -48,6 +48,8 @@ class Datetime(NwDatetime): + __slots__ = NwDatetime.__slots__ + @inherit_doc(NwDatetime) def __init__( self, time_unit: TimeUnit = "us", time_zone: str | timezone | None = None @@ -59,6 +61,8 @@ def __hash__(self) -> int: class Duration(NwDuration): + __slots__ = NwDuration.__slots__ + @inherit_doc(NwDuration) def __init__(self, time_unit: TimeUnit = "us") -> None: super().__init__(time_unit) @@ -81,6 +85,8 @@ class Enum(NwEnum): Enum """ + __slots__ = () + def __init__(self) -> None: super(NwEnum, self).__init__() diff --git a/tests/v1_test.py b/tests/v1_test.py index 37d5dc1779..67660f3614 100644 --- a/tests/v1_test.py +++ b/tests/v1_test.py @@ -50,6 +50,7 @@ from typing_extensions import assert_type from narwhals._typing import EagerAllowed + from narwhals.dtypes import DType from narwhals.stable.v1.typing import IntoDataFrameT from narwhals.typing import IntoDType, _1DArray, _2DArray from tests.utils import Constructor, ConstructorEager @@ -1115,3 +1116,11 @@ def test_mode_different_lengths(constructor_eager: ConstructorEager) -> None: df = nw_v1.from_native(constructor_eager({"a": [1, 1, 2], "b": [4, 5, 6]})) with pytest.raises(ShapeError): df.select(nw_v1.col("a", "b").mode()) + + +@pytest.mark.parametrize( + "dtype", [nw_v1.Datetime(), nw_v1.Duration(), nw_v1.Enum()], ids=str +) +def test_dtype___slots__(dtype: DType) -> None: + with pytest.raises(AttributeError): + dtype.i_also_dont_exist = 528329 # type: ignore[attr-defined] From 6dada91369474c3fed0703e4d3895768729742aa Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 11 Oct 2025 21:31:33 +0000 Subject: [PATCH 04/12] test: Check `__dict__` is a no-no --- tests/dtypes_test.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index ab25d89925..fe2dec98e7 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -51,7 +51,8 @@ nw.UInt128, nw.Unknown, nw.Binary, - ] + ], + ids=lambda tp: tp.__name__, ) def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]: return request.param # type: ignore[no-any-return] @@ -63,7 +64,8 @@ def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]: nw.Array(nw.String, 2), nw.Struct({"a": nw.Boolean}), nw.Enum(["beluga", "narwhal"]), - ] + ], + ids=lambda obj: type(obj).__name__, ) def nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType: return request.param # type: ignore[no-any-return] @@ -618,8 +620,15 @@ def test_dtype___slots___non_nested(non_nested_type: type[NonNestedDType]) -> No dtype = non_nested_type() with pytest.raises(AttributeError): dtype.i_dont_exist = 100 # type: ignore[union-attr] + with pytest.raises(AttributeError): + dtype.__dict__ # noqa: B018 + _ = dtype.__slots__ def test_dtype___slots___nested(nested_dtype: NestedOrEnumDType) -> None: with pytest.raises(AttributeError): nested_dtype.i_dont_exist = 999 # type: ignore[union-attr] + with pytest.raises(AttributeError): + nested_dtype.__dict__ # noqa: B018 + slots = nested_dtype.__slots__ + assert len(slots) != 0, slots From b2f6976a0da3868caa9d84fd1da266ec063b20e2 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 12 Oct 2025 09:56:40 +0200 Subject: [PATCH 05/12] add precommit check for slotted classes --- .pre-commit-config.yaml | 5 ++ utils/check_api_reference.py | 89 +++++++++++++++++----------------- utils/check_slotted_classes.py | 57 ++++++++++++++++++++++ 3 files changed, 107 insertions(+), 44 deletions(-) create mode 100644 utils/check_slotted_classes.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e862bf1741..9becf12c3b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,6 +47,11 @@ repos: entry: python -m utils.check_api_reference language: python additional_dependencies: [polars] + - id: check-slotted-classes + name: check-slotted-classes + pass_filenames: false + entry: python -m utils.check_slotted_classes + language: python - id: imports-are-banned name: import are banned (use `get_pandas` instead of `import pandas`) entry: python utils/import_check.py diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index 3233b24b1c..a32de124f2 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -1,3 +1,4 @@ +# ruff: noqa: T201 from __future__ import annotations import inspect @@ -118,48 +119,48 @@ def read_documented_members(source: str | Path) -> list[str]: documented = read_documented_members(DIR_API_REF / "narwhals.md") if missing := set(top_level_functions).difference(documented).difference({"annotations"}): - print("top-level functions: not documented") # noqa: T201 - print(missing) # noqa: T201 + print("top-level functions: not documented") + print(missing) ret = 1 if extra := set(documented).difference(top_level_functions): - print("top-level functions: outdated") # noqa: T201 - print(extra) # noqa: T201 + print("top-level functions: outdated") + print(extra) ret = 1 # DataFrame methods dataframe_methods = list(iter_api_reference_names(nw.DataFrame)) documented = read_documented_members(DIR_API_REF / "dataframe.md") if missing := set(dataframe_methods).difference(documented): - print("DataFrame: not documented") # noqa: T201 - print(missing) # noqa: T201 + print("DataFrame: not documented") + print(missing) ret = 1 if extra := set(documented).difference(dataframe_methods): - print("DataFrame: outdated") # noqa: T201 - print(extra) # noqa: T201 + print("DataFrame: outdated") + print(extra) ret = 1 # LazyFrame methods lazyframe_methods = list(iter_api_reference_names(nw.LazyFrame)) documented = read_documented_members(DIR_API_REF / "lazyframe.md") if missing := set(lazyframe_methods).difference(documented): - print("LazyFrame: not documented") # noqa: T201 - print(missing) # noqa: T201 + print("LazyFrame: not documented") + print(missing) ret = 1 if extra := set(documented).difference(lazyframe_methods): - print("LazyFrame: outdated") # noqa: T201 - print(extra) # noqa: T201 + print("LazyFrame: outdated") + print(extra) ret = 1 # Series methods series_methods = list(iter_api_reference_names(nw.Series)) documented = read_documented_members(DIR_API_REF / "series.md") if missing := set(series_methods).difference(documented).difference(NAMESPACES): - print("Series: not documented") # noqa: T201 - print(missing) # noqa: T201 + print("Series: not documented") + print(missing) ret = 1 if extra := set(documented).difference(series_methods): - print("Series: outdated") # noqa: T201 - print(extra) # noqa: T201 + print("Series: outdated") + print(extra) ret = 1 # Series.{cat, dt, list, str} methods @@ -171,24 +172,24 @@ def read_documented_members(source: str | Path) -> list[str]: ] documented = read_documented_members(DIR_API_REF / f"series_{namespace}.md") if missing := set(series_ns_methods).difference(documented): - print(f"Series.{namespace}: not documented") # noqa: T201 - print(missing) # noqa: T201 + print(f"Series.{namespace}: not documented") + print(missing) ret = 1 if extra := set(documented).difference(series_ns_methods): - print(f"Series.{namespace}: outdated") # noqa: T201 - print(extra) # noqa: T201 + print(f"Series.{namespace}: outdated") + print(extra) ret = 1 # Expr methods expr_methods = list(iter_api_reference_names(nw.Expr)) documented = read_documented_members(DIR_API_REF / "expr.md") if missing := set(expr_methods).difference(documented).difference(NAMESPACES): - print("Expr: not documented") # noqa: T201 - print(missing) # noqa: T201 + print("Expr: not documented") + print(missing) ret = 1 if extra := set(documented).difference(expr_methods): - print("Expr: outdated") # noqa: T201 - print(extra) # noqa: T201 + print("Expr: outdated") + print(extra) ret = 1 # Expr.{cat, dt, list, name, str} methods @@ -200,24 +201,24 @@ def read_documented_members(source: str | Path) -> list[str]: ] documented = read_documented_members(DIR_API_REF / f"expr_{namespace}.md") if missing := set(expr_ns_methods).difference(documented): - print(f"Expr.{namespace}: not documented") # noqa: T201 - print(missing) # noqa: T201 + print(f"Expr.{namespace}: not documented") + print(missing) ret = 1 if extra := set(documented).difference(expr_ns_methods): - print(f"Expr.{namespace}: outdated") # noqa: T201 - print(extra) # noqa: T201 + print(f"Expr.{namespace}: outdated") + print(extra) ret = 1 # DTypes dtypes = list(iter_api_reference_names_dtypes(nw.dtypes)) documented = read_documented_members(DIR_API_REF / "dtypes.md") if missing := set(dtypes).difference(documented): - print("Dtype: not documented") # noqa: T201 - print(missing) # noqa: T201 + print("Dtype: not documented") + print(missing) ret = 1 if extra := set(documented).difference(dtypes): - print("Dtype: outdated") # noqa: T201 - print(extra) # noqa: T201 + print("Dtype: outdated") + print(extra) ret = 1 # Schema @@ -228,22 +229,22 @@ def read_documented_members(source: str | Path) -> list[str]: .difference(documented) .difference(iter_api_reference_names(OrderedDict)) ): - print("Schema: not documented") # noqa: T201 - print(missing) # noqa: T201 + print("Schema: not documented") + print(missing) ret = 1 if extra := set(documented).difference(schema_methods): - print("Schema: outdated") # noqa: T201 - print(extra) # noqa: T201 + print("Schema: outdated") + print(extra) ret = 1 # Check Expr vs Series if missing := set(expr_methods).difference(series_methods).difference(EXPR_ONLY_METHODS): - print("In Expr but not in Series") # noqa: T201 - print(missing) # noqa: T201 + print("In Expr but not in Series") + print(missing) ret = 1 if extra := set(series_methods).difference(expr_methods).difference(SERIES_ONLY_METHODS): - print("In Series but not in Expr") # noqa: T201 - print(extra) # noqa: T201 + print("In Series but not in Expr") + print(extra) ret = 1 # Check Expr vs Series internal methods @@ -259,12 +260,12 @@ def read_documented_members(source: str | Path) -> list[str]: if not i[0].isupper() and i[0] != "_" ] if missing := set(expr_internal).difference(series_internal): - print(f"In Expr.{namespace} but not in Series.{namespace}") # noqa: T201 - print(missing) # noqa: T201 + print(f"In Expr.{namespace} but not in Series.{namespace}") + print(missing) ret = 1 if extra := set(series_internal).difference(expr_internal): - print(f"In Series.{namespace} but not in Expr.{namespace}") # noqa: T201 - print(extra) # noqa: T201 + print(f"In Series.{namespace} but not in Expr.{namespace}") + print(extra) ret = 1 sys.exit(ret) diff --git a/utils/check_slotted_classes.py b/utils/check_slotted_classes.py new file mode 100644 index 0000000000..c1c12b39dd --- /dev/null +++ b/utils/check_slotted_classes.py @@ -0,0 +1,57 @@ +# ruff: noqa: T201 + +from __future__ import annotations + +import inspect +import sys +from itertools import chain +from typing import TYPE_CHECKING + +import narwhals.dtypes as dtypes_main +import narwhals.stable.v1.dtypes as v1_dtypes +import narwhals.stable.v2.dtypes as v2_dtypes + +if TYPE_CHECKING: + from collections.abc import Generator + from types import ModuleType, UnionType + + from typing_extensions import TypeAlias + + if sys.version_info >= (3, 10): + _ClassInfo: TypeAlias = type | UnionType | tuple["_ClassInfo", ...] + else: + _ClassInfo: TypeAlias = type | tuple["_ClassInfo", ...] + + +base_dtype = dtypes_main.DType +field_type = dtypes_main.Field + + +def get_unslotted_classes( + module: ModuleType, bases: _ClassInfo +) -> Generator[tuple[str, ModuleType], None, None]: + """Find classes in a `module` that inherit from `bases` but don't define `__slots__`.""" + return ( + (name, module) + for name, cls in inspect.getmembers(module) + if isinstance(cls, type) + and issubclass(cls, bases) + and "__slots__" not in cls.__dict__ + ) + + +ret = 0 +unslotted_classes = tuple( + chain.from_iterable( + get_unslotted_classes(mod, bases=(base_dtype, field_type)) + for mod in (dtypes_main, v1_dtypes, v2_dtypes) + ) +) + +if unslotted_classes: + ret = 1 + msg = "The following classes are expected to define `__slots__` but they don't:\n" + cls_list = "\n".join(f" * {c[0]} from {c[1]}" for c in unslotted_classes) + print(f"{msg}{cls_list}") + +sys.exit(ret) From 7117d381bbfaaa8526c7be883f55cb2b77998230 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 12 Oct 2025 10:08:49 +0200 Subject: [PATCH 06/12] more explicit unpacking names --- utils/check_slotted_classes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/check_slotted_classes.py b/utils/check_slotted_classes.py index c1c12b39dd..59d7f49502 100644 --- a/utils/check_slotted_classes.py +++ b/utils/check_slotted_classes.py @@ -12,7 +12,7 @@ import narwhals.stable.v2.dtypes as v2_dtypes if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Iterator from types import ModuleType, UnionType from typing_extensions import TypeAlias @@ -29,7 +29,7 @@ def get_unslotted_classes( module: ModuleType, bases: _ClassInfo -) -> Generator[tuple[str, ModuleType], None, None]: +) -> Iterator[tuple[str, ModuleType]]: """Find classes in a `module` that inherit from `bases` but don't define `__slots__`.""" return ( (name, module) @@ -51,7 +51,7 @@ def get_unslotted_classes( if unslotted_classes: ret = 1 msg = "The following classes are expected to define `__slots__` but they don't:\n" - cls_list = "\n".join(f" * {c[0]} from {c[1]}" for c in unslotted_classes) + cls_list = "\n".join(f" * {name} from {mod}" for name, mod in unslotted_classes) print(f"{msg}{cls_list}") sys.exit(ret) From 05db297a6c7d90d42ebdad52d65388d0ab2d729b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 12 Oct 2025 10:42:53 +0000 Subject: [PATCH 07/12] ci(ruff): Add (`T201`) ignore to all `utils/*` --- pyproject.toml | 4 ++-- utils/check_api_reference.py | 1 - utils/check_dist_content.py | 4 ++-- utils/check_docstrings.py | 4 ++-- utils/check_slotted_classes.py | 2 -- utils/import_check.py | 4 ++-- 6 files changed, 8 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index da62cde2e4..27623fb567 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -210,8 +210,8 @@ extend-ignore-names = [ "PLR0916", # too-many-boolean-expressions ] "tpch/tests/*" = ["S101"] -"utils/*" = ["S311"] -"utils/bump_version.py" = ["S603", "S607", "T201"] +"utils/*" = ["S311", "T201"] +"utils/bump_version.py" = ["S603", "S607"] "tpch/execute/*" = ["T201"] "tpch/notebooks/*" = [ "ANN001", diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index a32de124f2..03ad115a5e 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -1,4 +1,3 @@ -# ruff: noqa: T201 from __future__ import annotations import inspect diff --git a/utils/check_dist_content.py b/utils/check_dist_content.py index 9a9922ef8d..62a387e5f6 100644 --- a/utils/check_dist_content.py +++ b/utils/check_dist_content.py @@ -18,7 +18,7 @@ } if unexpected_wheel_dirs: - print(f"🚨 Unexpected directories in wheel: {unexpected_wheel_dirs}") # noqa: T201 + print(f"🚨 Unexpected directories in wheel: {unexpected_wheel_dirs}") sys.exit(1) with TarFile.open(sdist_path, mode="r:gz") as sdist_file: @@ -35,7 +35,7 @@ } if unexpected_sdist_dirs := sdist_dirs - allowed_sdist_dirs: - print(f"🚨 Unexpected directories or files in sdist: {unexpected_sdist_dirs}") # noqa: T201 + print(f"🚨 Unexpected directories or files in sdist: {unexpected_sdist_dirs}") sys.exit(1) sys.exit(0) diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 12a6d6faaf..a3049eabd2 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -77,11 +77,11 @@ def report_errors(errors: list[str], temp_files: list[tuple[Path, str]]) -> None if not errors: return - print("❌ Ruff issues found in examples:\n") # noqa: T201 + print("❌ Ruff issues found in examples:\n") for line in errors: for temp_file, original_context in temp_files: if str(temp_file) in line: - print(f"{original_context}{line.replace(str(temp_file), '')}") # noqa: T201 + print(f"{original_context}{line.replace(str(temp_file), '')}") break diff --git a/utils/check_slotted_classes.py b/utils/check_slotted_classes.py index 59d7f49502..1d1b90df15 100644 --- a/utils/check_slotted_classes.py +++ b/utils/check_slotted_classes.py @@ -1,5 +1,3 @@ -# ruff: noqa: T201 - from __future__ import annotations import inspect diff --git a/utils/import_check.py b/utils/import_check.py index 17f6616fa6..d292b40790 100644 --- a/utils/import_check.py +++ b/utils/import_check.py @@ -56,7 +56,7 @@ def visit_Import(self, node: ast.Import) -> None: and alias.name not in self.allowed_imports and "# ignore-banned-import" not in self.lines[node.lineno - 1] ): - print( # noqa: T201 + print( f"{self.file_name}:{node.lineno}:{node.col_offset}: found {alias.name} import" ) self.found_import = True @@ -69,7 +69,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: and "# ignore-banned-import" not in self.lines[node.lineno - 1] and node.module not in self.allowed_imports ): - print( # noqa: T201 + print( f"{self.file_name}:{node.lineno}:{node.col_offset}: found {node.module} import" ) self.found_import = True From e96f4633c485084cc83d30ba7fe89e230ec41c44 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 12 Oct 2025 11:30:10 +0000 Subject: [PATCH 08/12] refactor(suggestion): Reuse `get_dtype_backend_test` logic See https://github.com/narwhals-dev/narwhals/pull/3194/files#r2423653800 --- utils/check_slotted_classes.py | 48 ++++++++++++---------------------- 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/utils/check_slotted_classes.py b/utils/check_slotted_classes.py index 1d1b90df15..4cdac1558e 100644 --- a/utils/check_slotted_classes.py +++ b/utils/check_slotted_classes.py @@ -1,55 +1,39 @@ from __future__ import annotations -import inspect import sys -from itertools import chain -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar -import narwhals.dtypes as dtypes_main -import narwhals.stable.v1.dtypes as v1_dtypes -import narwhals.stable.v2.dtypes as v2_dtypes +from narwhals.dtypes import DType, Field if TYPE_CHECKING: from collections.abc import Iterator - from types import ModuleType, UnionType - from typing_extensions import TypeAlias - if sys.version_info >= (3, 10): - _ClassInfo: TypeAlias = type | UnionType | tuple["_ClassInfo", ...] - else: - _ClassInfo: TypeAlias = type | tuple["_ClassInfo", ...] +T_co = TypeVar("T_co", covariant=True) -base_dtype = dtypes_main.DType -field_type = dtypes_main.Field +def _iter_descendants(*bases: type[T_co]) -> Iterator[type[T_co]]: + for base in bases: + if children := base.__subclasses__(): + yield from _iter_descendants(*children) + else: + yield base -def get_unslotted_classes( - module: ModuleType, bases: _ClassInfo -) -> Iterator[tuple[str, ModuleType]]: - """Find classes in a `module` that inherit from `bases` but don't define `__slots__`.""" - return ( - (name, module) - for name, cls in inspect.getmembers(module) - if isinstance(cls, type) - and issubclass(cls, bases) - and "__slots__" not in cls.__dict__ - ) +def iter_unslotted_classes(*bases: type[T_co]) -> Iterator[str]: + """Find classes in that inherit from `bases` but don't define `__slots__`.""" + for tp in sorted(set(_iter_descendants(*bases)), key=repr): + if "__slots__" not in tp.__dict__: + yield f"{tp.__module__}.{tp.__name__}" ret = 0 -unslotted_classes = tuple( - chain.from_iterable( - get_unslotted_classes(mod, bases=(base_dtype, field_type)) - for mod in (dtypes_main, v1_dtypes, v2_dtypes) - ) -) +unslotted_classes = tuple(iter_unslotted_classes(DType, Field)) if unslotted_classes: ret = 1 msg = "The following classes are expected to define `__slots__` but they don't:\n" - cls_list = "\n".join(f" * {name} from {mod}" for name, mod in unslotted_classes) + cls_list = "\n".join(f" * {name}" for name in unslotted_classes) print(f"{msg}{cls_list}") sys.exit(ret) From 3795a4937454f0e24a372d4ebbdc8482e35dccfe Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 12 Oct 2025 15:51:06 +0000 Subject: [PATCH 09/12] fix: Include more than leaves, make sure in scope https://github.com/narwhals-dev/narwhals/pull/3194#discussion_r2423744444 --- utils/check_slotted_classes.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/utils/check_slotted_classes.py b/utils/check_slotted_classes.py index 4cdac1558e..abe33ed8ab 100644 --- a/utils/check_slotted_classes.py +++ b/utils/check_slotted_classes.py @@ -3,6 +3,7 @@ import sys from typing import TYPE_CHECKING, TypeVar +from narwhals._utils import Version, qualified_type_name from narwhals.dtypes import DType, Field if TYPE_CHECKING: @@ -11,20 +12,24 @@ T_co = TypeVar("T_co", covariant=True) +# NOTE: For `__subclasses__` to work, all modules that descendants are defined in must be imported +_ = Version.MAIN.dtypes +_ = Version.V1.dtypes +_ = Version.V2.dtypes + def _iter_descendants(*bases: type[T_co]) -> Iterator[type[T_co]]: for base in bases: + yield base if children := base.__subclasses__(): yield from _iter_descendants(*children) - else: - yield base def iter_unslotted_classes(*bases: type[T_co]) -> Iterator[str]: """Find classes in that inherit from `bases` but don't define `__slots__`.""" - for tp in sorted(set(_iter_descendants(*bases)), key=repr): + for tp in sorted(set(_iter_descendants(*bases)), key=qualified_type_name): if "__slots__" not in tp.__dict__: - yield f"{tp.__module__}.{tp.__name__}" + yield qualified_type_name(tp) ret = 0 From d9c052055ab0f95d6e82c51c65229a846511bc05 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 12 Oct 2025 15:55:14 +0000 Subject: [PATCH 10/12] perf: huge speedup Noticed how slow it was during debugging (too much repeating) --- utils/check_slotted_classes.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/utils/check_slotted_classes.py b/utils/check_slotted_classes.py index abe33ed8ab..71704840a4 100644 --- a/utils/check_slotted_classes.py +++ b/utils/check_slotted_classes.py @@ -19,10 +19,13 @@ def _iter_descendants(*bases: type[T_co]) -> Iterator[type[T_co]]: + seen = set[T_co]() for base in bases: yield base - if children := base.__subclasses__(): - yield from _iter_descendants(*children) + if (children := (base.__subclasses__())) and ( + unseen := set(children).difference(seen) + ): + yield from _iter_descendants(*unseen) def iter_unslotted_classes(*bases: type[T_co]) -> Iterator[str]: From 28e3ef7bba44e3d6078dccb80e5871a6a179a712 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 12 Oct 2025 16:00:45 +0000 Subject: [PATCH 11/12] ci: Link to what `__slots__` even are --- utils/check_slotted_classes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/utils/check_slotted_classes.py b/utils/check_slotted_classes.py index 71704840a4..d979e44281 100644 --- a/utils/check_slotted_classes.py +++ b/utils/check_slotted_classes.py @@ -42,6 +42,8 @@ def iter_unslotted_classes(*bases: type[T_co]) -> Iterator[str]: ret = 1 msg = "The following classes are expected to define `__slots__` but they don't:\n" cls_list = "\n".join(f" * {name}" for name in unslotted_classes) + url = "https://docs.python.org/3/reference/datamodel.html#slots" + hint = f"Hint: See for detail {url!r}" print(f"{msg}{cls_list}") sys.exit(ret) From a404d0d013c93b1c2949fa10e72aadb42c5da569 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 12 Oct 2025 21:49:51 +0000 Subject: [PATCH 12/12] Even friendlier error message Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> --- utils/check_slotted_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/check_slotted_classes.py b/utils/check_slotted_classes.py index d979e44281..ffa73269e0 100644 --- a/utils/check_slotted_classes.py +++ b/utils/check_slotted_classes.py @@ -43,7 +43,7 @@ def iter_unslotted_classes(*bases: type[T_co]) -> Iterator[str]: msg = "The following classes are expected to define `__slots__` but they don't:\n" cls_list = "\n".join(f" * {name}" for name in unslotted_classes) url = "https://docs.python.org/3/reference/datamodel.html#slots" - hint = f"Hint: See for detail {url!r}" + hint = f"Hint: For more details see {url!r}" print(f"{msg}{cls_list}") sys.exit(ret)