Skip to content

Commit d5c00bc

Browse files
committed
feat(expr-ir): Start sketching out (de)serde
Very rough, but identifies the edge cases at least Child of (#2572)
1 parent 2b9dbf0 commit d5c00bc

File tree

2 files changed

+221
-4
lines changed

2 files changed

+221
-4
lines changed

narwhals/_plan/options.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
import enum
44
from itertools import repeat
5-
from typing import TYPE_CHECKING, Literal
5+
from typing import TYPE_CHECKING, Any, Literal
66

77
from narwhals._plan._immutable import Immutable
8+
from narwhals._utils import qualified_type_name
89

910
if TYPE_CHECKING:
10-
from collections.abc import Iterable, Sequence
11+
from collections.abc import Iterable, Mapping, Sequence
1112

1213
import pyarrow.acero
1314
import pyarrow.compute as pc
@@ -84,6 +85,25 @@ class FunctionOptions(Immutable):
8485
def __str__(self) -> str:
8586
return f"{type(self).__name__}(flags='{self.flags}')"
8687

88+
def to_dict(self, *, qualify_type_name: bool = False) -> dict[str, dict[str, int]]:
89+
name = qualified_type_name(self) if qualify_type_name else type(self).__name__
90+
return {name: {"flags": self.flags.value}}
91+
92+
@classmethod
93+
def from_dict(cls, mapping: Mapping[str, Any], /) -> FunctionOptions:
94+
flags = (mapping.get(cls.__name__) or mapping[qualified_type_name(cls)])["flags"]
95+
return cls.from_int(int(flags))
96+
97+
@classmethod
98+
def from_int(cls, value: int, /) -> FunctionOptions:
99+
return cls.from_flags(FunctionFlags(value))
100+
101+
@classmethod
102+
def from_flags(cls, flags: FunctionFlags, /) -> FunctionOptions:
103+
obj = FunctionOptions.__new__(FunctionOptions)
104+
object.__setattr__(obj, "flags", cls._ensure_valid_flags(flags))
105+
return obj
106+
87107
def is_elementwise(self) -> bool:
88108
return self.flags.is_elementwise()
89109

@@ -99,12 +119,16 @@ def is_row_separable(self) -> bool:
99119
def is_input_wildcard_expansion(self) -> bool:
100120
return self.flags.is_input_wildcard_expansion()
101121

102-
def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions:
122+
@staticmethod
123+
def _ensure_valid_flags(flags: FunctionFlags, /) -> FunctionFlags:
103124
if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags:
104125
msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive."
105126
raise TypeError(msg)
127+
return flags
128+
129+
def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions:
106130
obj = FunctionOptions.__new__(FunctionOptions)
107-
object.__setattr__(obj, "flags", self.flags | flags)
131+
object.__setattr__(obj, "flags", self.flags | self._ensure_valid_flags(flags))
108132
return obj
109133

110134
def with_elementwise(self) -> FunctionOptions:

narwhals/_plan/serde.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""(De)serialization support for `ExprIR` & `Immutable`.
2+
3+
Planned for use by:
4+
- `Expr.meta.serialize`
5+
- `Expr.deserialize`
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import re
11+
from importlib import import_module
12+
from typing import TYPE_CHECKING, Any, Literal, Union, overload
13+
14+
from narwhals._plan import expressions as ir
15+
from narwhals._plan._immutable import Immutable
16+
from narwhals._plan.options import FunctionOptions
17+
from narwhals._plan.series import Series
18+
from narwhals._utils import Version, isinstance_or_issubclass, qualified_type_name
19+
from narwhals.dtypes import DType, _is_into_dtype, _validate_into_dtype
20+
21+
if TYPE_CHECKING:
22+
from collections.abc import Iterator, Mapping
23+
24+
from typing_extensions import TypeAlias
25+
26+
from narwhals._plan.expr import Expr
27+
from narwhals.typing import FileSource
28+
29+
__all__ = ["deserialize", "from_dict", "serialize", "to_dict"]
30+
31+
dtypes = Version.MAIN.dtypes
32+
33+
JSONLiteral: TypeAlias = (
34+
"Union[str, int, float, list[JSONLiteral], Mapping[str, JSONLiteral], None]"
35+
)
36+
SerializationFormat: TypeAlias = Literal["binary", "json"]
37+
38+
# e.g. `tuple[tuple[...]]` is not allowed
39+
SUPPORTED_NESTED = (str, int, float, type(None), Immutable, DType, type)
40+
# NOTE: Intentionally omits `list`, `dict`
41+
SUPPORTED = (*SUPPORTED_NESTED, Series, tuple)
42+
43+
SELECTORS_ONLY = frozenset, re.Pattern
44+
45+
# udf + `.name` functions
46+
UNLIKELY = (callable,)
47+
48+
# Not sure how to handle
49+
PLANNED = SELECTORS_ONLY
50+
51+
UNSUPPORTED_DTYPES = (
52+
dtypes.Datetime,
53+
dtypes.Duration,
54+
dtypes.Enum,
55+
dtypes.Struct,
56+
dtypes.List,
57+
dtypes.Array,
58+
)
59+
60+
61+
def _not_yet_impl_error(
62+
obj: Immutable, field: str, value: Any, origin: Literal["to_dict"]
63+
) -> NotImplementedError:
64+
msg = (
65+
f"Found an expected type in {field!r}: {qualified_type_name(value)!r}\n"
66+
f"but serde (`{origin}`) support has not yet been implemented.\n\n{obj!s}"
67+
)
68+
return NotImplementedError(msg)
69+
70+
71+
def _unrecognized_type_error(obj: Immutable, field: str, value: Any) -> TypeError:
72+
msg = f"Found an unrecognized type in {field!r}: {qualified_type_name(value)!r}\n\n{obj!s}"
73+
return TypeError(msg)
74+
75+
76+
def _to_list_iter(
77+
field: str, obj: tuple[object, ...], /, *, qualify_type_name: bool, owner: Immutable
78+
) -> Iterator[JSONLiteral]:
79+
for element in obj:
80+
if isinstance(element, SUPPORTED_NESTED):
81+
yield _to_values(
82+
field, element, qualify_type_name=qualify_type_name, owner=owner
83+
)
84+
elif isinstance(element, PLANNED):
85+
raise _not_yet_impl_error(owner, field, element, "to_dict")
86+
else:
87+
raise _unrecognized_type_error(owner, field, element)
88+
89+
90+
def _to_values(
91+
field: str, v: object, /, *, qualify_type_name: bool, owner: Immutable
92+
) -> JSONLiteral:
93+
# Handled by json, not nested
94+
if isinstance(v, (str, int, float, type(None))):
95+
return v
96+
# Literal / is_in
97+
if isinstance(v, Series):
98+
return v.to_list()
99+
# Primary cases
100+
if isinstance(v, Immutable):
101+
if isinstance(v, FunctionOptions):
102+
return v.to_dict(qualify_type_name=qualify_type_name)
103+
return _to_dict(v, qualify_type_name=qualify_type_name)
104+
# Primary nesting case
105+
if isinstance(v, (tuple,)):
106+
return list(
107+
_to_list_iter(field, v, qualify_type_name=qualify_type_name, owner=owner)
108+
)
109+
if isinstance_or_issubclass(v, DType):
110+
return _dtype_to_path_str(v)
111+
raise _not_yet_impl_error(owner, field, v, "to_dict")
112+
113+
114+
def _to_dict_children_iter(
115+
obj: Immutable, /, *, qualify_type_name: bool
116+
) -> Iterator[tuple[str, JSONLiteral]]:
117+
for k, v in obj.__immutable_items__:
118+
# Handled by json, not nested
119+
if isinstance(v, SUPPORTED):
120+
yield k, _to_values(k, v, qualify_type_name=qualify_type_name, owner=obj)
121+
elif isinstance(v, PLANNED) or callable(v):
122+
raise _not_yet_impl_error(obj, k, v, "to_dict")
123+
else: # FunctionFlags Enum
124+
raise _unrecognized_type_error(obj, k, v)
125+
126+
127+
def _dtype_to_path_str(obj: DType | type[DType]) -> str:
128+
if isinstance_or_issubclass(obj, UNSUPPORTED_DTYPES):
129+
msg = f"DType serialization is not yet supported for {qualified_type_name(obj)!r}"
130+
raise NotImplementedError(msg)
131+
return qualified_type_name(obj)
132+
133+
134+
def _dtype_from_path_str(path_str: str) -> DType:
135+
parts = path_str.rsplit(".", maxsplit=1)
136+
imported_module = import_module(parts[0])
137+
tp = getattr(imported_module, parts[1])
138+
if not _is_into_dtype(tp):
139+
_validate_into_dtype(tp)
140+
return dtypes.Unknown()
141+
return tp.base_type()()
142+
143+
144+
def _to_dict(
145+
obj: ir.ExprIR | Immutable, /, *, qualify_type_name: bool = False
146+
) -> dict[str, dict[str, JSONLiteral]]:
147+
leaf_name = qualified_type_name(obj) if qualify_type_name else type(obj).__name__
148+
return {
149+
leaf_name: dict(_to_dict_children_iter(obj, qualify_type_name=qualify_type_name))
150+
}
151+
152+
153+
def to_dict(expr: Expr | ir.ExprIR, /) -> dict[str, dict[str, JSONLiteral]]:
154+
return _to_dict(expr if isinstance(expr, ir.ExprIR) else expr._ir)
155+
156+
157+
def from_dict(mapping: Mapping[str, Mapping[str, JSONLiteral]], /) -> Any:
158+
msg = "`serde.from_dict` is not yet implemented"
159+
raise NotImplementedError(msg)
160+
161+
162+
@overload
163+
def serialize(
164+
expr: Expr | ir.ExprIR, /, file: None = ..., *, format: Literal["binary"] = ...
165+
) -> bytes: ...
166+
167+
168+
@overload
169+
def serialize(
170+
expr: Expr | ir.ExprIR, /, file: None = ..., *, format: Literal["json"]
171+
) -> str: ...
172+
173+
174+
@overload
175+
def serialize(
176+
expr: Expr | ir.ExprIR, /, file: FileSource, *, format: SerializationFormat = ...
177+
) -> None: ...
178+
179+
180+
def serialize(
181+
expr: Expr | ir.ExprIR,
182+
/,
183+
file: FileSource | None = None,
184+
*,
185+
format: SerializationFormat = "binary",
186+
) -> bytes | str | None:
187+
msg = "`serde.serialize` is not yet implemented"
188+
raise NotImplementedError(msg)
189+
190+
191+
def deserialize(source: FileSource, *, format: SerializationFormat = "binary") -> Expr:
192+
msg = "`serde.deserialize` is not yet implemented"
193+
raise NotImplementedError(msg)

0 commit comments

Comments
 (0)