Hi All,
Is there a preferred way to compute the Jacobian of a non-holomorphic function within Equinox? Something akin to a filter_jacrev_nonholomorphic?
The error message itself just says to use jax.vjp directly:
TypeError: jacrev requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex64. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.
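Following the error message's suggestion, here is a minimal sketch of using jax.vjp directly for a plain function f: R^n -> C^m with a real input. It assumes real inputs only. The helper name jac_complex_output is hypothetical. The approach pulls back one-hot cotangents through the VJPs of Re f and Im f, much like jax.jacrev does internally, and then recombines the two real Jacobians:

```python
import jax
import jax.numpy as jnp


def jac_complex_output(f, x):
    """Jacobian of f: R^n -> C^m at a real input x, built from two real VJPs.

    Hypothetical helper; assumes x is a real floating-point array.
    """
    # Re f and Im f are real -> real, so their VJPs need no complex conventions.
    y_re, vjp_re = jax.vjp(lambda v: jnp.real(f(v)), x)
    _, vjp_im = jax.vjp(lambda v: jnp.imag(f(v)), x)
    # One-hot cotangents, one per (flattened) output element.
    n_out = y_re.size
    basis = jnp.eye(n_out, dtype=y_re.dtype).reshape((n_out,) + y_re.shape)
    # Each vjp_fn returns a tuple with one cotangent per primal input.
    (rows_re,) = jax.vmap(vjp_re)(basis)
    (rows_im,) = jax.vmap(vjp_im)(basis)
    # Row k holds d(Re f_k)/dx and d(Im f_k)/dx; recombine into a complex Jacobian.
    return (rows_re + 1j * rows_im).reshape(y_re.shape + x.shape)


x = jnp.array([0.0, 0.5, 1.0])
jac = jac_complex_output(lambda v: jnp.exp(1j * v), x)
```

For f(x) = exp(i x) applied elementwise, the result should be the diagonal matrix with entries i exp(i x_k). Splitting into real and imaginary parts sidesteps the question of how JAX treats complex cotangents.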
So I'd assume the best way to do this is to use eqx.partition to split the PyTree into its params and static parts, compute the real and imaginary parts separately via jax.vjp (rather than eqx.filter_vjp), and then return vjp_fn as a new partial via functools.partial or eqx.Partial. Is that right?