@@ -181,7 +181,7 @@ def _error_method_assignment(self, value: object, /) -> None:
181181 raise ValueError (MSG_METHOD_IN_INIT )
182182
183183
184- _transform_types = {
184+ _transform_types : set [ type ] = {
185185 type (transform (lambda x : x ))
186186 for transform in (
187187 jax .jit ,
@@ -208,7 +208,7 @@ def _is_array_like(x: object, /) -> None:
208208 raise _JaxTransformException
209209
210210
211- MSG_JAX_XFM_FUNC : Final = """
211+ _MSG_JAX_XFM_FUNC : Final = """
212212Possibly assigning a JAX-transformed callable as an attribute on
213213{0}.{1}. This will not have any of its parameters updated.
214214
@@ -263,7 +263,7 @@ def _warn_jax_transformed_function(cls: "_ModuleMeta", x: object) -> None:
263263 jtu .tree_map (_is_array_like , x )
264264 except _JaxTransformException :
265265 warnings .warn (
266- MSG_JAX_XFM_FUNC .format (cls .__module__ , cls .__qualname__ ),
266+ _MSG_JAX_XFM_FUNC .format (cls .__module__ , cls .__qualname__ ),
267267 stacklevel = 3 ,
268268 )
269269 break
@@ -319,6 +319,44 @@ def __init__(self, field1, field2)
319319""" [1 :]
320320
321321
322+ _MSG_FIELD_INIT_FALSE : Final = """
323+ Using `field(init=False)` on `equinox.Module` can lead to surprising behaviour
324+ when used around `jax.grad`. In the following example, observe how JAX computes
325+ gradients with respect to the `.len` attribute (which is a PyTree leaf passed
326+ across the `jax.grad` boundary) and that there are no gradients with respect to
327+ `.a` or `.b`:
328+
329+ ```
330+ import equinox as eqx
331+ import jax
332+ import jax.numpy as jnp
333+
334+ class Foo(eqx.Module):
335+ a: jax.Array
336+ b: jax.Array
337+ len: jax.Array = eqx.field(init=False)
338+
339+ def __post_init__(self):
340+ self.len = jnp.sqrt(self.a**2 + self.b**2)
341+
342+ def __call__(self, x):
343+ return self.len * x
344+
345+ @jax.jit
346+ @jax.grad
347+ def call(module, x):
348+ return module(x)
349+
350+ grads = call(Foo(jnp.array(3.0), jnp.array(4.0)), 5)
351+ # Foo(
352+ # a=Array(0., dtype=float32, weak_type=True),
353+ # b=Array(0., dtype=float32, weak_type=True),
354+ # len=Array(5., dtype=float32, weak_type=True)
355+ # )
356+ ```
357+ """ [1 :]
358+
359+
322360# This deliberately does not pass `frozen_default=True`, as that clashes with custom
323361# `__init__` methods.
324362@dataclass_transform (field_specifiers = (dataclasses .field , field ))
@@ -471,44 +509,7 @@ def __call__(cls, *args: object, **kwargs: object): # noqa: N805
471509 )
472510 if not f .init :
473511 if any (jtu .tree_map (is_inexact_array_like , jtu .tree_leaves (val ))):
474- warnings .warn (
475- "Using `field(init=False)` on `equinox.Module` can lead to "
476- "surprising behaviour when used around `jax.grad`. In the "
477- "following example, observe how JAX computes gradients with "
478- "respect to the `.len` attribute (which is a PyTree leaf "
479- "passed across the `jax.grad` boundary) and that there are no "
480- "gradients with respect to `.a` or `.b`:\n "
481- "\n "
482- "```\n "
483- "import equinox as eqx\n "
484- "import jax\n "
485- "import jax.numpy as jnp\n "
486- "\n "
487- "class Foo(eqx.Module):\n "
488- " a: jax.Array\n "
489- " b: jax.Array\n "
490- " len: jax.Array = eqx.field(init=False)\n "
491- "\n "
492- " def __post_init__(self):\n "
493- " self.len = jnp.sqrt(self.a**2 + self.b**2)\n "
494- "\n "
495- " def __call__(self, x):\n "
496- " return self.len * x\n "
497- "\n "
498- "@jax.jit\n "
499- "@jax.grad\n "
500- "def call(module, x):\n "
501- " return module(x)\n "
502- "\n "
503- "grads = call(Foo(jnp.array(3.0), jnp.array(4.0)), 5)\n "
504- "# Foo(\n "
505- "# a=Array(0., dtype=float32, weak_type=True),\n "
506- "# b=Array(0., dtype=float32, weak_type=True),\n "
507- "# len=Array(5., dtype=float32, weak_type=True)\n "
508- "# )\n "
509- "```" ,
510- stacklevel = 2 ,
511- )
512+ warnings .warn (_MSG_FIELD_INIT_FALSE , stacklevel = 2 )
512513
513514 for parent_cls in cls .__mro__ :
514515 try :
0 commit comments