Hello,
I've encountered a problem when trying to build a Gauss-Newton-like matrix out of `lineax.FunctionLinearOperator`s. My `test_func` takes and returns a pytree. The final line, where I try to compute `j2`, fails with an `AssertionError`.

Am I doing something wrong, or is what I'm trying to do just not possible?
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import lineax as lx


class my_value(eqx.Module):
    x: jax.Array
    y: jax.Array


def test_func(x):
    # Elementwise |x|^2 on every leaf of the pytree.
    return jax.tree.map(lambda x: jnp.abs(x) ** 2, x)


x = my_value(jnp.ones((3, 1)), jnp.ones((3, 1)))
jac = lx.FunctionLinearOperator(jax.jacobian(test_func), jax.eval_shape(lambda: x))
j2 = jac.transpose() @ jac  # fails with the AssertionError below
```
```
site-packages/jax/_src/lax/lax.py:7400, in _select_transpose_rule(t, which, *cases)
   7399 def _select_transpose_rule(t, which, *cases):
-> 7400   assert not ad.is_undefined_primal(which)
   7401   if type(t) is ad_util.Zero:
   7402     return [None] + [ad_util.Zero(c.aval) if ad.is_undefined_primal(c) else None
   7403                      for c in cases]

AssertionError:
```