-
Can you say more about what you mean by "check trainability"? Tracer subclasses are an internal implementation detail, and not something user code should rely on.
-
Thanks for your answer! Is there a way to access the information about which parameters were set as trainable by the user (the argnums)?
-
@mattjj, @patrick-kidger and I recently added symbolic zero support to `jax.custom_jvp` and `jax.custom_vjp`. Using this, if an array doesn't participate in differentiation, then it will enter your rule as a symbolic zero, not take up any memory, and your custom JVP can handle things accordingly from there. Example:

```python
from functools import partial
import jax

Zero = jax.custom_derivatives.SymbolicZero

@jax.custom_jvp
def f(x, y):
    return x * y

@partial(f.defjvp, symbolic_zeros=True)
def fjvp(primals, tangents):
    x, y = primals
    tx, ty = tangents
    # Tangents of inputs that don't participate in differentiation
    # arrive as SymbolicZero instances rather than zero-filled arrays.
    if type(tx) is Zero and type(ty) is Zero:
        tz = None  # neither input participates
    elif type(tx) is Zero:
        tz = x * ty
    elif type(ty) is Zero:
        tz = y * tx
    else:
        tz = x * ty + y * tx
    return f(x, y), tz
```

```python
>>> jax.jacfwd(f, argnums=0)(2., 3.)
Array(3., dtype=float32, weak_type=True)
>>> jax.jacfwd(f, argnums=1)(2., 3.)
Array(2., dtype=float32, weak_type=True)
>>> jax.jacfwd(f, argnums=[0, 1])(2., 3.)
(Array(3., dtype=float32, weak_type=True), Array(2., dtype=float32, weak_type=True))
```

As @jakevdp wrote, traces and tracers are jax-internal implementation details.
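To connect this back to the trainability question, here is a minimal sketch of how symbolic zeros could replace the tracer check. The function `g` and its rule below are hypothetical (written for this thread, not part of JAX); the idea is to inspect tangent types rather than internal tracer classes.

```python
from functools import partial
import jax
import jax.numpy as jnp

Zero = jax.custom_derivatives.SymbolicZero

@jax.custom_jvp
def g(x, y):
    return x + y

@partial(g.defjvp, symbolic_zeros=True)
def g_jvp(primals, tangents):
    x, y = primals
    tx, ty = tangents
    # A SymbolicZero tangent means the corresponding input was not
    # selected for differentiation (e.g. not among the caller's argnums).
    print("x participates:", type(tx) is not Zero)
    print("y participates:", type(ty) is not Zero)
    if type(tx) is Zero and type(ty) is Zero:
        tz = jnp.zeros_like(x)  # no input participates (assumes x, y share shape/dtype)
    elif type(tx) is Zero:
        tz = ty
    elif type(ty) is Zero:
        tz = tx
    else:
        tz = tx + ty
    return g(x, y), tz
```

With this, `jax.jacfwd(g, argnums=0)(2., 3.)` should print `x participates: True` and `y participates: False` at trace time, without touching any tracer internals.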
-
Hi JAX team! This behaviour changed in 0.4.4: previously we would get exactly the same trace on the parameters when swapping the order of grad and jit, but now we no longer see the same information on the trace.
Is there still a way to access the JVPTrace as in the first example? We are using `isinstance(x, JVPTracer)` to check trainability in a complicated workflow.
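For readers following along, here is a minimal sketch of the fragile pattern described above. The function `f` is hypothetical, and the `JVPTracer` import path is jax-internal, so it may move or disappear between versions:

```python
import jax
from jax.interpreters.ad import JVPTracer  # jax-internal, subject to change

def f(x, y):
    # Fragile: which tracer class wraps each input depends on the
    # surrounding transformation stack, not just on trainability.
    print("x is JVPTracer:", isinstance(x, JVPTracer))
    print("y is JVPTracer:", isinstance(y, JVPTracer))
    return x * y

jax.grad(f, argnums=0)(2., 3.)           # x is traced for differentiation
jax.grad(jax.jit(f), argnums=0)(2., 3.)  # under jit, f may see jit's tracers instead
```

On recent JAX versions the two calls need not show the same tracer types inside `f`, which is why the answers above suggest symbolic zeros as the supported alternative.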