Skip to content

Commit 9b609e4

Browse files
nstarmanpatrick-kidger
authored andcommitted
refactor: move field(init=False) warning message
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 20414a7 commit 9b609e4

File tree

1 file changed

+42
-41
lines changed

1 file changed

+42
-41
lines changed

equinox/_module/_module.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _error_method_assignment(self, value: object, /) -> None:
181181
raise ValueError(MSG_METHOD_IN_INIT)
182182

183183

184-
_transform_types = {
184+
_transform_types: set[type] = {
185185
type(transform(lambda x: x))
186186
for transform in (
187187
jax.jit,
@@ -208,7 +208,7 @@ def _is_array_like(x: object, /) -> None:
208208
raise _JaxTransformException
209209

210210

211-
MSG_JAX_XFM_FUNC: Final = """
211+
_MSG_JAX_XFM_FUNC: Final = """
212212
Possibly assigning a JAX-transformed callable as an attribute on
213213
{0}.{1}. This will not have any of its parameters updated.
214214
@@ -263,7 +263,7 @@ def _warn_jax_transformed_function(cls: "_ModuleMeta", x: object) -> None:
263263
jtu.tree_map(_is_array_like, x)
264264
except _JaxTransformException:
265265
warnings.warn(
266-
MSG_JAX_XFM_FUNC.format(cls.__module__, cls.__qualname__),
266+
_MSG_JAX_XFM_FUNC.format(cls.__module__, cls.__qualname__),
267267
stacklevel=3,
268268
)
269269
break
@@ -319,6 +319,44 @@ def __init__(self, field1, field2)
319319
"""[1:]
320320

321321

322+
_MSG_FIELD_INIT_FALSE: Final = """
323+
Using `field(init=False)` on `equinox.Module` can lead to surprising behaviour
324+
when used around `jax.grad`. In the following example, observe how JAX computes
325+
gradients with respect to the `.len` attribute (which is a PyTree leaf passed
326+
across the `jax.grad` boundary) and that there are no gradients with respect to
327+
`.a` or `.b`:
328+
329+
```
330+
import equinox as eqx
331+
import jax
332+
import jax.numpy as jnp
333+
334+
class Foo(eqx.Module):
335+
a: jax.Array
336+
b: jax.Array
337+
len: jax.Array = eqx.field(init=False)
338+
339+
def __post_init__(self):
340+
self.len = jnp.sqrt(self.a**2 + self.b**2)
341+
342+
def __call__(self, x):
343+
return self.len * x
344+
345+
@jax.jit
346+
@jax.grad
347+
def call(module, x):
348+
return module(x)
349+
350+
grads = call(Foo(jnp.array(3.0), jnp.array(4.0)), 5)
351+
# Foo(
352+
# a=Array(0., dtype=float32, weak_type=True),
353+
# b=Array(0., dtype=float32, weak_type=True),
354+
# len=Array(5., dtype=float32, weak_type=True)
355+
# )
356+
```
357+
"""[1:]
358+
359+
322360
# This deliberately does not pass `frozen_default=True`, as that clashes with custom
323361
# `__init__` methods.
324362
@dataclass_transform(field_specifiers=(dataclasses.field, field))
@@ -471,44 +509,7 @@ def __call__(cls, *args: object, **kwargs: object): # noqa: N805
471509
)
472510
if not f.init:
473511
if any(jtu.tree_map(is_inexact_array_like, jtu.tree_leaves(val))):
474-
warnings.warn(
475-
"Using `field(init=False)` on `equinox.Module` can lead to "
476-
"surprising behaviour when used around `jax.grad`. In the "
477-
"following example, observe how JAX computes gradients with "
478-
"respect to the `.len` attribute (which is a PyTree leaf "
479-
"passed across the `jax.grad` boundary) and that there are no "
480-
"gradients with respect to `.a` or `.b`:\n"
481-
"\n"
482-
"```\n"
483-
"import equinox as eqx\n"
484-
"import jax\n"
485-
"import jax.numpy as jnp\n"
486-
"\n"
487-
"class Foo(eqx.Module):\n"
488-
" a: jax.Array\n"
489-
" b: jax.Array\n"
490-
" len: jax.Array = eqx.field(init=False)\n"
491-
"\n"
492-
" def __post_init__(self):\n"
493-
" self.len = jnp.sqrt(self.a**2 + self.b**2)\n"
494-
"\n"
495-
" def __call__(self, x):\n"
496-
" return self.len * x\n"
497-
"\n"
498-
"@jax.jit\n"
499-
"@jax.grad\n"
500-
"def call(module, x):\n"
501-
" return module(x)\n"
502-
"\n"
503-
"grads = call(Foo(jnp.array(3.0), jnp.array(4.0)), 5)\n"
504-
"# Foo(\n"
505-
"# a=Array(0., dtype=float32, weak_type=True),\n"
506-
"# b=Array(0., dtype=float32, weak_type=True),\n"
507-
"# len=Array(5., dtype=float32, weak_type=True)\n"
508-
"# )\n"
509-
"```",
510-
stacklevel=2,
511-
)
512+
warnings.warn(_MSG_FIELD_INIT_FALSE, stacklevel=2)
512513

513514
for parent_cls in cls.__mro__:
514515
try:

0 commit comments

Comments
 (0)