
Problem with usage of Linear Operators #171

@matillda123


Hello,

I've encountered a problem when trying to build a Gauss-Newton-like matrix (J^T J) out of lineax.FunctionLinearOperators. My test_func takes and returns a pytree. The final line, where I try to compute j2, fails with an AssertionError (traceback below).
Am I doing something wrong, or is what I'm trying to do simply not possible?

import lineax as lx
import equinox as eqx
import jax
import jax.numpy as jnp


class my_value(eqx.Module):
    x: jax.Array
    y: jax.Array


def test_func(x):
    # Elementwise |x|^2 on every leaf of the pytree.
    return jax.tree.map(lambda leaf: jnp.abs(leaf) ** 2, x)


x = my_value(jnp.ones((3, 1)), jnp.ones((3, 1)))
jac = lx.FunctionLinearOperator(jax.jacobian(test_func), jax.eval_shape(lambda: x))
j2 = jac.transpose() @ jac  # fails with AssertionError

site-packages/jax/_src/lax/lax.py:7400, in _select_transpose_rule(t, which, *cases)
7399 def _select_transpose_rule(t, which, *cases):
->7400 assert not ad.is_undefined_primal(which)
7401 if type(t) is ad_util.Zero:
7402 return [None] + [ad_util.Zero(c.aval) if ad.is_undefined_primal(c) else None
7403 for c in cases]
AssertionError:
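One plausible reading of this traceback (a hedged diagnosis, not a confirmed answer): transposing a FunctionLinearOperator goes through jax.linear_transpose, which only supports functions that are linear in their input. jax.jacobian(test_func) is not such a function; it evaluates the Jacobian at whatever point it receives. The JVP of jnp.abs branches on the sign of its input via a select, so when that input is itself the quantity being transposed, JAX hits `assert not ad.is_undefined_primal(which)` in `_select_transpose_rule`. Linearizing at a fixed point first yields a genuinely linear map. The plain-JAX sketch below checks this on a single array standing in for the my_value pytree; `f`, `x`, `u` and `v` are illustrative names, not taken from the issue.

```python
import jax
import jax.numpy as jnp


def f(a):
    # Same elementwise |a|^2 map as test_func, on a single array.
    return jnp.abs(a) ** 2


x = jnp.array([1.0, -2.0, 3.0])

# Linearize f at the fixed point x: jvp_fn(v) = J(x) @ v is linear in v,
# and the sign of x needed by abs's derivative is now a known constant.
y, jvp_fn = jax.linearize(f, x)

# Transposing a genuinely linear function works: vjp_fn(u) = J(x)^T @ u.
# linear_transpose returns one cotangent per primal argument, hence the tuple.
vjp_fn = jax.linear_transpose(jvp_fn, x)

u = jnp.array([1.0, 1.0, 1.0])
(ju,) = vjp_fn(u)

# Reference value from ordinary reverse-mode autodiff.
_, vjp_ref = jax.vjp(f, x)
(ju_ref,) = vjp_ref(u)

# Gauss-Newton-like product J^T J v, computed without forming J.
v = jnp.array([1.0, 0.0, 2.0])
(jtjv,) = vjp_fn(jvp_fn(v))
```

Here J(x) = diag(2x) = diag(2, -4, 6), so the transpose only ever multiplies by constants already fixed at x, which is why jax.linear_transpose accepts it.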
