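This seems to be the closest idea:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
m = 100
# Two "dataframes": dicts of equal-length columns.
data = {k: jax.random.normal(key, (m, )) for k in ['x', 'y', 'z']}
data_ = {k: jax.random.normal(key, (m, )) for k in ['x', 'y', 'z']}

def g(x, x_):
    # Pairwise Laplace kernel between x and x_. atleast_1d lets g accept a
    # scalar (one row, under vmap) as well as a full column.
    x = jnp.atleast_1d(x)
    return jnp.exp(- jnp.abs(x[:, None] - x_[None, :]))

def f(data, data_):
    # Sum the kernel over all columns.
    out = 0
    for k in data.keys():
        out += g(data[k], data_[k])
    return out

res = f(data, data_)  # shape (m, m)
# Map over axis 0 of every column in `data`; pass `data_` through unchanged.
ff = jax.vmap(f, in_axes=({k: 0 for k in data.keys()}, {k: None for k in data.keys()}))
res2 = ff(data, data_).squeeze()  # raw output has shape (m, 1, m)
assert jnp.allclose(res, res2)
```

It ends up with an extra dimension, but that's probably unavoidable and fine.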
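A variant worth considering, offered only as a sketch: the extra dimension comes from `jnp.atleast_1d`, which turns each vmapped scalar into shape `(1,)`, so every call returns `(1, m)`. If the kernel is written for a single row instead, `vmap` returns the full matrix directly. The names `row_kernel` and `gram` are hypothetical, and the split keys and different sizes `m` and `n` are there only so the columns differ and the output shape is easy to read. The shorter `in_axes=(0, None)` should be equivalent to the per-key dicts above, since `in_axes` accepts a pytree prefix and a single `0` applies to every leaf of the dict.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.PRNGKey(0), 6)
m, n = 100, 50
cols = ["x", "y", "z"]
# Separate keys so each column holds different values.
data = {c: jax.random.normal(k, (m,)) for c, k in zip(cols, keys[:3])}
data_ = {c: jax.random.normal(k, (n,)) for c, k in zip(cols, keys[3:])}

def row_kernel(row, data_):
    # row: dict of scalars (one row of `data`); data_: dict of (n,) columns.
    # Returns the (n,) vector of Laplace-kernel terms summed over columns.
    return sum(jnp.exp(-jnp.abs(row[c] - data_[c])) for c in row)

# 0 maps every leaf of `data` over axis 0; None passes `data_` unchanged.
gram = jax.vmap(row_kernel, in_axes=(0, None))(data, data_)  # shape (m, n)
```

Because `row_kernel` never adds a singleton axis, no `.squeeze()` is needed afterwards.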
Is there a good pattern for using vmap over a dict of arrays? Something like this
The backstory is that this is a multi-dtype situation (ints and floats before embedding), so I've gone down the route of using dicts as dataframes.
I feel like I've done something like this before, but I've forgotten whether the pattern is possible.
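A minimal sketch of the basic pattern, assuming every column in the dict shares the same leading (row) axis. `vmap` treats the dict as a pytree, so `in_axes=0` for that argument maps each leaf over axis 0, and the function sees one row as a dict of scalars whatever their dtypes are. The column names, the `embedding` table and `featurize` are hypothetical stand-ins for the int/float pre-embedding case.

```python
import jax
import jax.numpy as jnp

m = 4
# Hypothetical "dataframe": one dict entry per column, all sharing leading axis m.
table = {
    "category": jnp.array([0, 2, 1, 2], dtype=jnp.int32),  # integer ids, pre-embedding
    "price": jnp.array([1.0, 2.5, 0.5, 4.0], dtype=jnp.float32),
}
# Hypothetical embedding table: 3 categories x 2 dims.
embedding = jax.random.normal(jax.random.PRNGKey(0), (3, 2))

def featurize(row, embedding):
    # row is a dict of scalars: one int32 and one float32.
    emb = embedding[row["category"]]                    # shape (2,)
    return jnp.concatenate([emb, row["price"][None]])   # shape (3,)

# 0 maps over axis 0 of every column; None shares `embedding` across rows.
features = jax.vmap(featurize, in_axes=(0, None))(table, embedding)  # shape (m, 3)
```

Mixed dtypes are not a problem because each leaf is batched independently; the only requirement is that all mapped leaves have the same size along the mapped axis.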