Skip to content

Commit 0277a30

Browse files
nstarmanpatrick-kidger
authored andcommitted
perf: code generation
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 9b609e4 commit 0277a30

File tree

4 files changed

+257
-106
lines changed

4 files changed

+257
-106
lines changed

equinox/_module/_flatten.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""Utilities for generating optimized flatten/unflatten functions for Module."""
2+
3+
import dataclasses
4+
import textwrap
5+
from enum import Enum
6+
from typing import Any, Final, Literal, TypeVar
7+
8+
import jax.tree_util as jtu
9+
10+
11+
# Names for generated functions
12+
_FLAT_FUNC_NAME = "<generated_flatten_{0}>"
13+
_FLAT_KEYS_NAME = "<generated_flatten_with_keys_{0}>"
14+
_UNFLAT_NAME = "<generated_unflatten_{0}>"
15+
16+
# Fields of Module wrappers to always include
17+
WRAPPER_FIELD_NAMES: Final = (
18+
"__module__",
19+
"__name__",
20+
"__qualname__",
21+
"__doc__",
22+
"__annotations__",
23+
)
24+
25+
# Indentation for generated code
26+
_INDENT: Final = " " * 4
27+
28+
29+
# Sentinel values for flattening/unflattening
30+
class _Sentinel(Enum):
31+
"""Sentinel values for flattening/unflattening."""
32+
33+
MISSING = "MISSING"
34+
35+
36+
MISSING = _Sentinel.MISSING
37+
38+
# Code template for flattening the wrapper fields
39+
_GETTER = "get({k!r},MISSING)" # get = obj.__dict__.get
40+
_FLAT_WRAPPERS = ", ".join([_GETTER.format(k=k) for k in WRAPPER_FIELD_NAMES])
41+
42+
# Code template for flatten function
43+
_FLAT_CODE = f'''
44+
def {{func_name}}(obj: module_cls) -> {{return_annotation}}:
45+
"""Generated {{func_name}} function for {{qualname}}.
46+
47+
Dynamic fields: {{dynamic_fs}}
48+
Static fields: {{static_fs}}
49+
Wrapper fields: {WRAPPER_FIELD_NAMES}
50+
"""
51+
get = obj.__dict__.get
52+
return (
53+
{{dynamic_vals}},
54+
(
55+
MISSING if len(obj.__dict__) == {{num_fields}} else ({_FLAT_WRAPPERS}),
56+
{{static_vals}}
57+
)
58+
)
59+
'''
60+
61+
# Code template for setting wrapper fields during unflattening
62+
_SET_WRAPPERS: Final = "\n".join(
63+
f"self.__dict__[{name!r}] = waux[{i}]" for i, name in enumerate(WRAPPER_FIELD_NAMES)
64+
)
65+
66+
# Code template for setting dynamic fields during unflattening
67+
_SET_DYNAMIC = """
68+
if data[{i}] is not MISSING:
69+
object.__setattr__(self, {name!r}, data[{i}])
70+
"""[1:-1] # (trim leading and trailing newlines)
71+
72+
# Code template for setting static fields during unflattening
73+
_SET_STATIC = """
74+
if aux[{i}] is not MISSING:
75+
object.__setattr__(self, {name!r}, aux[{i}])
76+
"""[1:-1] # (trim leading and trailing newlines)
77+
78+
# Code template for unflatten function
79+
_UNFLAT_FUNC = f'''
80+
def unflatten(
81+
module_cls: type[T],
82+
aux: {{aux_type}},
83+
data: {{dynamic_type}},
84+
) -> T:
85+
"""Generated unflatten function for {{qualname}}.
86+
87+
Dynamic fields: {{dynamic_fs}}
88+
Static fields: {{static_fs}}
89+
Wrapper fields: {WRAPPER_FIELD_NAMES}
90+
"""
91+
self = object.__new__(module_cls)
92+
# Set fields directly by index
93+
{{setters_dynamic}}
94+
if aux[0] is not MISSING:
95+
waux = aux[0]
96+
{textwrap.indent(_SET_WRAPPERS, _INDENT * 2)}
97+
{{setters_static}}
98+
return self
99+
'''
100+
101+
_NS_BASE = {
102+
"object": object,
103+
"Any": Any,
104+
"tuple": tuple,
105+
"Literal": Literal,
106+
"MISSING": MISSING,
107+
}
108+
109+
110+
def _make_tuple_type(count: int, elt: str, /) -> str:
111+
"""Generate a tuple type annotation string for a given count of elements."""
112+
return "tuple[()]" if count == 0 else f"tuple[{', '.join([elt] * count)}]"
113+
114+
115+
def generate_flatten_functions(cls: type, fields: tuple[dataclasses.Field[Any], ...]):
116+
"""Generate optimized flatten/unflatten functions for a specific field config."""
117+
# Separate dynamic and static fields
118+
_dynamic_fs, _static_fs = [], []
119+
for f in fields:
120+
if f.metadata.get("static", False):
121+
_static_fs.append(f.name)
122+
else:
123+
_dynamic_fs.append(f.name)
124+
dynamic_fs, static_fs = tuple(_dynamic_fs), tuple(_static_fs)
125+
n_dynamic_fs = len(dynamic_fs)
126+
n_static_fs = len(static_fs)
127+
128+
# Extract the generated functions from respective namespaces to
129+
# set the proper module reference.
130+
module_name = getattr(cls, "__module__", cls.__name__)
131+
132+
# -------------------------------------------
133+
# Generate flatten function
134+
135+
# Directly access dynamic fields by name
136+
if n_dynamic_fs == 0:
137+
dynamic_vals = "()"
138+
else:
139+
dynamic_exprs = [_GETTER.format(k=k) for k in dynamic_fs]
140+
dynamic_vals = f"({', '.join(dynamic_exprs)},)"
141+
142+
# For static fields, we need to store their values in aux data
143+
static_exprs = [_GETTER.format(k=k) for k in static_fs]
144+
static_vals = f"{', '.join(static_exprs)}"
145+
146+
# Build return type annotation
147+
dynamic_type = _make_tuple_type(n_dynamic_fs, "Any")
148+
wrapper_type = "Literal[MISSING]|tuple[Any, ...]"
149+
static_type = f"tuple[{wrapper_type}, {', '.join(['Any'] * n_static_fs)}]"
150+
clsname = cls.__qualname__
151+
152+
# Generate flatten function code
153+
flat_code = _FLAT_CODE.format(
154+
func_name="flatten",
155+
return_annotation=f"tuple[{dynamic_type}, {static_type}]",
156+
qualname=clsname,
157+
dynamic_fs=dynamic_fs,
158+
static_fs=static_fs,
159+
num_fields=n_dynamic_fs + n_static_fs,
160+
dynamic_vals=dynamic_vals,
161+
static_vals=static_vals,
162+
)
163+
164+
# make flatten func
165+
flat_ns = _NS_BASE | {"module_cls": cls}
166+
exec(compile(flat_code, _FLAT_FUNC_NAME.format(clsname), "exec"), flat_ns)
167+
flat_fn = flat_ns["flatten"]
168+
flat_fn.__module__ = module_name
169+
object.__setattr__(flat_fn, "__source__", flat_code)
170+
171+
# -------------------------------------------
172+
# Generate flatten_with_keys function
173+
174+
# Generate flatten_with_keys values
175+
if n_dynamic_fs == 0:
176+
dynamic_key_vals = "()"
177+
else:
178+
key_exprs = [
179+
f"(jtu.GetAttrKey({k!r}), {_GETTER.format(k=k)})" for k in dynamic_fs
180+
]
181+
dynamic_key_vals = f"({', '.join(key_exprs)},)"
182+
183+
keys_dynamic_type = _make_tuple_type(n_dynamic_fs, "tuple[jtu.GetAttrKey, str]")
184+
185+
# Generate flatten_with_keys function code
186+
flat_k_code = _FLAT_CODE.format(
187+
func_name="flatten_with_keys",
188+
return_annotation=f"tuple[{keys_dynamic_type}, {static_type}]",
189+
qualname=clsname,
190+
dynamic_fs=dynamic_fs,
191+
static_fs=static_fs,
192+
num_fields=n_dynamic_fs + n_static_fs,
193+
dynamic_vals=dynamic_key_vals,
194+
static_vals=static_vals,
195+
)
196+
197+
# flatten with keys func
198+
flat_k_ns = _NS_BASE | {"jtu": jtu, "module_cls": cls}
199+
exec(compile(flat_k_code, _FLAT_KEYS_NAME.format(clsname), "exec"), flat_k_ns)
200+
flat_k_fn = flat_k_ns["flatten_with_keys"]
201+
flat_k_fn.__module__ = module_name
202+
object.__setattr__(flat_k_fn, "__source__", flat_k_code)
203+
204+
# -------------------------------------------
205+
# Generate unflatten function - directly set fields by index
206+
# Extract types from flatten return type: tuple[dynamic_type, static_type]
207+
208+
# Set dynamic fields by index
209+
unflat_dynamic = [
210+
_SET_DYNAMIC.format(i=i, name=k) for i, k in enumerate(dynamic_fs)
211+
]
212+
# Set static fields by index. Offset by 1 for wrapper auxiliary field.
213+
unflat_aux = [
214+
_SET_STATIC.format(i=i, name=k) for i, k in enumerate(static_fs, start=1)
215+
]
216+
217+
# Generate unflatten function code
218+
unflat_code = _UNFLAT_FUNC.format(
219+
aux_type=static_type,
220+
dynamic_type=dynamic_type,
221+
qualname=clsname,
222+
dynamic_fs=dynamic_fs,
223+
static_fs=static_fs,
224+
setters_dynamic=textwrap.indent("\n".join(unflat_dynamic), _INDENT),
225+
setters_static=textwrap.indent("\n".join(unflat_aux), _INDENT),
226+
)
227+
228+
# Namespace for unflatten function (takes module_cls as parameter)
229+
unflat_ns = _NS_BASE | {"type": type, "T": TypeVar("T")}
230+
# unflatten
231+
exec(compile(unflat_code, _UNFLAT_NAME.format(clsname), "exec"), unflat_ns)
232+
unflat_fn = unflat_ns["unflatten"]
233+
unflat_fn.__module__ = module_name
234+
object.__setattr__(unflat_fn, "__source__", unflat_code)
235+
236+
# -------------------------------------------
237+
238+
return flat_fn, flat_k_fn, unflat_fn

0 commit comments

Comments
 (0)