diff --git a/jax/_src/api.py b/jax/_src/api.py index 5622a6cdfdee..7553da7ae3eb 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -497,12 +497,18 @@ def value_and_grad_f(*args, **kwargs): for leaf in tree_leaves(dyn_args): _check_input_dtype_grad(holomorphic, allow_int, leaf) if not has_aux: - ans, vjp_py = _vjp(f_partial, *dyn_args) + ans, vjp_py = _vjp3(f_partial, *dyn_args) else: - ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True) + ans, vjp_py, aux = _vjp3(f_partial, *dyn_args, has_aux=True) _check_scalar(ans) tree_map(partial(_check_output_dtype_grad, holomorphic), ans) - g = vjp_py(lax_internal._one(ans)) + + # g = vjp_py(lax_internal._one(ans)) + + grad_accum = tree_map(lambda x: core.new_ref(lax_internal._zeros(x)), dyn_args) + vjp_py.with_refs(*grad_accum)(lax_internal._one(ans)) + g = tree_map(lambda x: core.freeze(x), grad_accum) + g = g[0] if isinstance(argnums, int) else g if not has_aux: return ans, g