-
Hello, JAX has the fantastic feature of transparently taking gradients through Python dicts:

def f(inputs: dict[str, jax.Array]):
    x, y = inputs['x'], inputs['y']
    return x*x + jnp.sin(y)

features = dict(x=jnp.array(1.), y=jnp.array(2.))
jax.grad(f)(features) returns a dictionary with gradient components named after the variables, brilliant:
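For concreteness, the returned gradient pytree looks roughly like this (illustrative output added here, values approximate, not copied from an actual run):

print(jax.grad(f)(features))
# roughly: {'x': Array(2., dtype=float32, ...), 'y': Array(-0.41614684, dtype=float32, ...)}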
Question: Given a function like the one above, is there a built-in way to take the gradient with respect to only a subset of the entries in the dictionary? If not, how would you implement it? Thanks in advance!
-
In the spirit of Cunningham's law, here's what I cobbled together to solve the above problem:

import jax
import jax.numpy as jnp
from typing import Callable

def grad_wrt_key(f: Callable, wrt: list[str]):
    """Given a callable f that takes a dictionary of jax arrays as input,
    return a callable that evaluates the gradient of f w.r.t. one or more of
    the arrays in the dictionary (specified by key).
    """
    def grad_wrt_impl(inputs: dict[str, jax.Array]):
        in_vars = list(inputs.keys())
        argnums = [in_vars.index(wrt_var) for wrt_var in wrt]
        def f_with_positionals(*args):
            args_as_dict = dict(zip(in_vars, args))
            return f(args_as_dict)
        grads = jax.grad(f_with_positionals, argnums)(*inputs.values())
        return dict(zip(wrt, grads))
    return grad_wrt_impl

Usable as:

def f(inputs: dict[str, jax.Array]):
    x, y = inputs["x"], inputs["y"]
    return x * x + jnp.sin(y)

df = grad_wrt_key(f, wrt=["x"])
features = {'x': jnp.array(1.0), 'y': jnp.array(2.0)}
print(df(features))
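As a small illustration added here (not part of the original snippet), differentiating with respect to several keys just means listing them in wrt, since each key is mapped to a positional argnum:

# Hypothetical extension of the usage above; output values approximate.
df_xy = grad_wrt_key(f, wrt=["x", "y"])
print(df_xy(features))
# roughly: {'x': Array(2., dtype=float32, ...), 'y': Array(-0.41614684, dtype=float32, ...)}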
Besides the general clunkiness of ...
-
I wrote a slightly more general version of this but I think it's on a hard drive that currently doesn't have a home. It does (almost) the same thing but for general pytrees.

import jax
import jax.numpy as jnp

def masked_grad(f, mask):
    flat_mask, mask_structure = jax.tree_util.tree_flatten(mask)

    def flat_f(diff_args, nondiff_args):
        diff_iter = iter(diff_args)
        nondiff_iter = iter(nondiff_args)
        combined_args = [
            next(diff_iter if m else nondiff_iter)
            for m in flat_mask
        ]
        unflattened_args = mask_structure.unflatten(combined_args)
        return f(*unflattened_args)

    flat_grad = jax.grad(flat_f)

    def grad_fn(*args):
        flat_args, arg_structure = jax.tree_util.tree_flatten(args)
        assert arg_structure == mask_structure
        diff_args = []
        nondiff_args = []
        arg_it = iter(flat_args)
        for m in flat_mask:
            if m:
                diff_args.append(next(arg_it))
            else:
                nondiff_args.append(next(arg_it))
        grads = iter(flat_grad(diff_args, nondiff_args))
        # What to return here probably depends on what you want to do with the grads.
        # Using None won't play nice with a lot of stuff that uses pytrees.
        # Could use float0, but optax doesn't like those.
        placeholder = jnp.zeros(1)
        flat_grads_mixed = [next(grads) if m else placeholder for m in flat_mask]
        return arg_structure.unflatten(flat_grads_mixed)

    return grad_fn

def f(args):
    return args['x'] * args['y']

mask = ({'x': True, 'y': False},)  # This has to be a tuple
grad_fn = masked_grad(f, mask)
print(grad_fn({'x': 1., 'y': 2.}))
# Output:
# ({'x': Array(2., dtype=float32, weak_type=True), 'y': Array([0.], dtype=float32)},)
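As a rough sketch of the "general pytrees" part (this example is an addition, not from the original post), the same helper also handles functions of several arguments, as long as the mask has one boolean per leaf of the flattened argument tuple:

# Hypothetical two-argument example: differentiate w.r.t. params['w'] only.
def loss(params, data):
    return jnp.sum((params['w'] * data - params['b']) ** 2)

loss_mask = ({'w': True, 'b': False}, False)  # one bool per leaf of (params, data)
grad_loss = masked_grad(loss, loss_mask)
print(grad_loss({'w': jnp.array(3.0), 'b': jnp.array(1.0)}, jnp.array(2.0)))
# roughly: ({'b': Array([0.], ...), 'w': Array(20., ...)}, Array([0.], ...))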
Drawbacks (fixable): ...
-
There is no easy built-in way to do this, but generalizing argnums to handle arbitrary pytrees is something that's been frequently discussed. See #3875, #10614, and references within. I think the solutions suggested in the other answers here are probably the best option in the current version of JAX.
-
How about using Equinox's filter system? https://docs.kidger.site/equinox/api/filtering/partition-combine/
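A rough sketch of what that could look like (this illustration is an addition, assuming Equinox is installed and using eqx.partition / eqx.combine as documented at the link above):

import equinox as eqx
import jax
import jax.numpy as jnp

def f(inputs):
    x, y = inputs['x'], inputs['y']
    return x * x + jnp.sin(y)

features = {'x': jnp.array(1.0), 'y': jnp.array(2.0)}

# Boolean filter spec: differentiate w.r.t. 'x' only.
filter_spec = {'x': True, 'y': False}

# Split the inputs into a differentiable part and a static part.
diff, nondiff = eqx.partition(features, filter_spec)

def f_diff(diff_part):
    # Recombine before calling the original function.
    return f(eqx.combine(diff_part, nondiff))

print(jax.grad(f_diff)(diff))
# roughly: {'x': Array(2., dtype=float32, ...), 'y': None}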