Skip to content

Commit bf544d4

Browse files
authored
fix(expr-ir): Ensure only __slots__, and not __dict__ too (#3201)
1 parent 1a433a9 commit bf544d4

File tree

11 files changed

+331
-115
lines changed

11 files changed

+331
-115
lines changed

narwhals/_plan/_immutable.py

Lines changed: 65 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,21 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Literal, TypeVar
3+
from typing import TYPE_CHECKING
4+
5+
# ruff: noqa: N806
6+
from narwhals._plan._meta import ImmutableMeta
47

58
if TYPE_CHECKING:
6-
from collections.abc import Iterator
7-
from typing import Any, Callable
8-
9-
from typing_extensions import Never, Self, dataclass_transform
10-
11-
else:
12-
# https://docs.python.org/3/library/typing.html#typing.dataclass_transform
13-
def dataclass_transform(
14-
*,
15-
eq_default: bool = True,
16-
order_default: bool = False,
17-
kw_only_default: bool = False,
18-
frozen_default: bool = False,
19-
field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (),
20-
**kwargs: Any,
21-
) -> Callable[[T], T]:
22-
def decorator(cls_or_fn: T) -> T:
23-
cls_or_fn.__dataclass_transform__ = {
24-
"eq_default": eq_default,
25-
"order_default": order_default,
26-
"kw_only_default": kw_only_default,
27-
"frozen_default": frozen_default,
28-
"field_specifiers": field_specifiers,
29-
"kwargs": kwargs,
30-
}
31-
return cls_or_fn
32-
33-
return decorator
34-
35-
36-
T = TypeVar("T")
37-
_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__"
38-
39-
40-
@dataclass_transform(kw_only_default=True, frozen_default=True)
41-
class Immutable:
9+
from collections.abc import Iterable, Iterator
10+
from typing import Any, ClassVar, Final
11+
12+
from typing_extensions import Never, Self
13+
14+
15+
_HASH_NAME: Final = "__immutable_hash_value__"
16+
17+
18+
class Immutable(metaclass=ImmutableMeta):
4219
"""A poor man's frozen dataclass.
4320
4421
- Keyword-only constructor (IDE supported)
@@ -49,40 +26,43 @@ class Immutable:
4926
[`copy.replace`]: https://docs.python.org/3.13/library/copy.html#copy.replace
5027
"""
5128

52-
__slots__ = (_IMMUTABLE_HASH_NAME,)
53-
__immutable_hash_value__: int
29+
__slots__ = (_HASH_NAME,)
30+
if not TYPE_CHECKING:
31+
# NOTE: Trying to avoid this being added to synthesized `__init__`
32+
# Seems to be the only difference when decorating the metaclass
33+
__immutable_hash_value__: int
5434

55-
@property
56-
def __immutable_keys__(self) -> Iterator[str]:
57-
slots: tuple[str, ...] = self.__slots__
58-
for name in slots:
59-
if name != _IMMUTABLE_HASH_NAME:
60-
yield name
35+
__immutable_keys__: ClassVar[tuple[str, ...]]
6136

6237
@property
6338
def __immutable_values__(self) -> Iterator[Any]:
39+
"""Override to configure hash seed."""
40+
getattr_ = getattr
6441
for name in self.__immutable_keys__:
65-
yield getattr(self, name)
42+
yield getattr_(self, name)
6643

6744
@property
6845
def __immutable_items__(self) -> Iterator[tuple[str, Any]]:
46+
getattr_ = getattr
6947
for name in self.__immutable_keys__:
70-
yield name, getattr(self, name)
48+
yield name, getattr_(self, name)
7149

7250
@property
7351
def __immutable_hash__(self) -> int:
74-
if hasattr(self, _IMMUTABLE_HASH_NAME):
75-
return self.__immutable_hash_value__
76-
hash_value = hash((self.__class__, *self.__immutable_values__))
77-
object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value)
78-
return self.__immutable_hash_value__
52+
HASH = _HASH_NAME
53+
if hasattr(self, HASH):
54+
hash_value: int = getattr(self, HASH)
55+
else:
56+
hash_value = hash((self.__class__, *self.__immutable_values__))
57+
object.__setattr__(self, HASH, hash_value)
58+
return hash_value
7959

8060
def __setattr__(self, name: str, value: Never) -> Never:
8161
msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set."
8262
raise AttributeError(msg)
8363

8464
def __replace__(self, **changes: Any) -> Self:
85-
"""https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415
65+
"""https://docs.python.org/3.13/library/copy.html#copy.replace."""
8666
if len(changes) == 1:
8767
# The most common case is a single field replacement.
8868
# Iff that field happens to be equal, we can noop, preserving the current object's hash.
@@ -96,13 +76,6 @@ def __replace__(self, **changes: Any) -> Self:
9676
changes[name] = value_current
9777
return type(self)(**changes)
9878

99-
def __init_subclass__(cls, *args: Any, **kwds: Any) -> None:
100-
super().__init_subclass__(*args, **kwds)
101-
if cls.__slots__:
102-
...
103-
else:
104-
cls.__slots__ = ()
105-
10679
def __hash__(self) -> int:
10780
return self.__immutable_hash__
10881

@@ -111,35 +84,26 @@ def __eq__(self, other: object) -> bool:
11184
return True
11285
if type(self) is not type(other):
11386
return False
87+
getattr_ = getattr
11488
return all(
115-
getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__
89+
getattr_(self, key) == getattr_(other, key) for key in self.__immutable_keys__
11690
)
11791

11892
def __str__(self) -> str:
11993
fields = ", ".join(f"{_field_str(k, v)}" for k, v in self.__immutable_items__)
12094
return f"{type(self).__name__}({fields})"
12195

12296
def __init__(self, **kwds: Any) -> None:
123-
required: set[str] = set(self.__immutable_keys__)
124-
if not required and not kwds:
125-
# NOTE: Fastpath for empty slots
126-
...
127-
elif required == set(kwds):
128-
for name, value in kwds.items():
129-
object.__setattr__(self, name, value)
130-
elif missing := required.difference(kwds):
131-
msg = (
132-
f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n"
133-
f"but missing values for {sorted(missing)!r}"
134-
)
135-
raise TypeError(msg)
136-
else:
137-
extra = set(kwds).difference(required)
138-
msg = (
139-
f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n"
140-
f"but got unknown arguments {sorted(extra)!r}"
141-
)
142-
raise TypeError(msg)
97+
if (keys := self.__immutable_keys__) or kwds:
98+
required = set(keys)
99+
if required == kwds.keys():
100+
object__setattr__ = object.__setattr__
101+
for name, value in kwds.items():
102+
object__setattr__(self, name, value)
103+
elif missing := required.difference(kwds):
104+
raise _init_missing_error(self, required, missing)
105+
else:
106+
raise _init_extra_error(self, required, set(kwds).difference(required))
143107

144108

145109
def _field_str(name: str, value: Any) -> str:
@@ -149,3 +113,23 @@ def _field_str(name: str, value: Any) -> str:
149113
if isinstance(value, str):
150114
return f"{name}={value!r}"
151115
return f"{name}={value}"
116+
117+
118+
def _init_missing_error(
119+
obj: object, required: Iterable[str], missing: Iterable[str]
120+
) -> TypeError:
121+
msg = (
122+
f"{type(obj).__name__!r} requires attributes {sorted(required)!r}, \n"
123+
f"but missing values for {sorted(missing)!r}"
124+
)
125+
return TypeError(msg)
126+
127+
128+
def _init_extra_error(
129+
obj: object, required: Iterable[str], extra: Iterable[str]
130+
) -> TypeError:
131+
msg = (
132+
f"{type(obj).__name__!r} only supports attributes {sorted(required)!r}, \n"
133+
f"but got unknown arguments {sorted(extra)!r}"
134+
)
135+
return TypeError(msg)

narwhals/_plan/_meta.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Metaclasses and other unholy metaprogramming nonsense."""
2+
3+
from __future__ import annotations
4+
5+
# ruff: noqa: N806
6+
from itertools import chain
7+
from typing import TYPE_CHECKING
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Callable
11+
from typing import Any, Final, TypeVar
12+
13+
import _typeshed
14+
from typing_extensions import dataclass_transform
15+
16+
from narwhals._plan.typing import Seq
17+
18+
T = TypeVar("T")
19+
20+
else:
21+
# https://docs.python.org/3/library/typing.html#typing.dataclass_transform
22+
def dataclass_transform(
23+
*,
24+
eq_default: bool = True,
25+
order_default: bool = False,
26+
kw_only_default: bool = False,
27+
frozen_default: bool = False,
28+
field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (),
29+
**kwargs: Any,
30+
) -> Callable[[T], T]:
31+
def decorator(cls_or_fn: T) -> T:
32+
cls_or_fn.__dataclass_transform__ = {
33+
"eq_default": eq_default,
34+
"order_default": order_default,
35+
"kw_only_default": kw_only_default,
36+
"frozen_default": frozen_default,
37+
"field_specifiers": field_specifiers,
38+
"kwargs": kwargs,
39+
}
40+
return cls_or_fn
41+
42+
return decorator
43+
44+
45+
__all__ = ["ImmutableMeta", "SlottedMeta", "dataclass_transform"]
46+
47+
flatten = chain.from_iterable
48+
_KEYS_NAME: Final = "__immutable_keys__"
49+
_HASH_NAME: Final = "__immutable_hash_value__"
50+
51+
52+
class SlottedMeta(type):
53+
"""Ensure [`__slots__`] are always defined to prevent `__dict__` creation.
54+
55+
[`__slots__`]: https://docs.python.org/3/reference/datamodel.html#object.__slots__
56+
"""
57+
58+
# https://github.com/python/typeshed/blob/776508741d76b58f9dcb2aaf42f7d4596a48d580/stdlib/abc.pyi#L13-L19
59+
# https://github.com/python/typeshed/blob/776508741d76b58f9dcb2aaf42f7d4596a48d580/stdlib/_typeshed/__init__.pyi#L36-L40
60+
# https://github.com/astral-sh/ruff/issues/8353#issuecomment-1786238311
61+
# https://docs.python.org/3/reference/datamodel.html#creating-the-class-object
62+
def __new__(
63+
metacls: type[_typeshed.Self],
64+
cls_name: str,
65+
bases: tuple[type, ...],
66+
namespace: dict[str, Any],
67+
/,
68+
**kwds: Any,
69+
) -> _typeshed.Self:
70+
namespace.setdefault("__slots__", ())
71+
return super().__new__(metacls, cls_name, bases, namespace, **kwds) # type: ignore[no-any-return, misc]
72+
73+
74+
@dataclass_transform(kw_only_default=True, frozen_default=True)
75+
class ImmutableMeta(SlottedMeta):
76+
def __new__(
77+
metacls: type[_typeshed.Self],
78+
cls_name: str,
79+
bases: tuple[type, ...],
80+
namespace: dict[str, Any],
81+
/,
82+
**kwds: Any,
83+
) -> _typeshed.Self:
84+
KEYS, HASH = _KEYS_NAME, _HASH_NAME
85+
getattr_: Callable[..., Seq[str]] = getattr
86+
it_bases = (getattr_(b, KEYS, ()) for b in bases)
87+
it_all = chain(
88+
flatten(it_bases), namespace.get(KEYS, namespace.get("__slots__", ()))
89+
)
90+
namespace[KEYS] = tuple(key for key in it_all if key != HASH)
91+
return super().__new__(metacls, cls_name, bases, namespace, **kwds) # type: ignore[no-any-return, misc]

narwhals/_plan/expr.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,11 @@ def hist(
194194
if bin_count is not None:
195195
msg = "can only provide one of `bin_count` or `bins`"
196196
raise ComputeError(msg)
197-
node = F.HistBins(bins=tuple(bins), include_breakpoint=include_breakpoint)
197+
node = F.Hist.from_bins(bins, include_breakpoint=include_breakpoint)
198198
elif bin_count is not None:
199-
node = F.HistBinCount(
200-
bin_count=bin_count, include_breakpoint=include_breakpoint
201-
)
199+
node = F.Hist.from_bin_count(bin_count, include_breakpoint=include_breakpoint)
202200
else:
203-
node = F.HistBinCount(include_breakpoint=include_breakpoint)
201+
node = F.Hist.from_bin_count(include_breakpoint=include_breakpoint)
204202
return self._with_unary(node)
205203

206204
def log(self, base: float = math.e) -> Self:

narwhals/_plan/expressions/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,14 @@
1010
from narwhals._plan.expressions import (
1111
aggregation,
1212
boolean,
13+
categorical,
1314
functions,
15+
lists,
1416
operators,
1517
selectors,
18+
strings,
19+
struct,
20+
temporal,
1621
)
1722
from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr, max, min
1823
from narwhals._plan.expressions.expr import (
@@ -85,10 +90,12 @@
8590
"_ColumnSelection",
8691
"aggregation",
8792
"boolean",
93+
"categorical",
8894
"col",
8995
"cols",
9096
"functions",
9197
"index_columns",
98+
"lists",
9299
"max",
93100
"min",
94101
"named_ir",
@@ -97,4 +104,7 @@
97104
"over",
98105
"over_ordered",
99106
"selectors",
107+
"strings",
108+
"struct",
109+
"temporal",
100110
]

narwhals/_plan/expressions/aggregation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,18 @@ class ArgMin(OrderableAggExpr): ...
5050
class ArgMax(OrderableAggExpr): ...
5151
# fmt: on
5252
class Quantile(AggExpr):
53-
__slots__ = (*AggExpr.__slots__, "interpolation", "quantile")
53+
__slots__ = ("interpolation", "quantile")
5454
quantile: float
5555
interpolation: RollingInterpolationMethod
5656

5757

5858
class Std(AggExpr):
59-
__slots__ = (*AggExpr.__slots__, "ddof")
59+
__slots__ = ("ddof",)
6060
ddof: int
6161

6262

6363
class Var(AggExpr):
64-
__slots__ = (*AggExpr.__slots__, "ddof")
64+
__slots__ = ("ddof",)
6565
ddof: int
6666

6767

narwhals/_plan/expressions/expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ class OrderedWindowExpr(
407407
child=("expr", "partition_by", "order_by"),
408408
config=ExprIROptions.renamed("over_ordered"),
409409
):
410-
__slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023
410+
__slots__ = ("order_by", "sort_options")
411411
expr: ExprIR
412412
partition_by: Seq[ExprIR]
413413
order_by: Seq[ExprIR]

0 commit comments

Comments
 (0)