PRNGKey handling with sharding + jit #24479
Unanswered
giovannicemin asked this question in Ideas
Replies: 1 comment 1 reply
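Wrapping the function with jax.vmap, so that each call sees a single key:

from functools import partial
import jax

@partial(jax.vmap, in_axes=(0, None, None))  # map over keys only; a and b are shared
def f(key, a, b):
    k1, k2 = jax.random.split(key)  # inside vmap, key is a single key
    return random_choice(k1, a), random_choice(k2, b)  # random_choice: helper from the question (not shown)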
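The vmap wrapping can be exercised end to end with sharded keys under jit. The sketch below makes a few assumptions not in the thread: random_choice is a hypothetical stand-in built on jax.random.choice, the mesh axis name "batch" is arbitrary, and the batch size is chosen to be divisible by the device count (a single CPU device also works).

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def random_choice(key, x):
    # Hypothetical stand-in for the question's helper: pick one element of x.
    return jax.random.choice(key, x)


@partial(jax.vmap, in_axes=(0, None, None))
def f(key, a, b):
    # Inside vmap, `key` is one key, so jax.random.split accepts it.
    k1, k2 = jax.random.split(key)
    return random_choice(k1, a), random_choice(k2, b)


# One mesh axis spanning all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
batch = 2 * len(jax.devices())  # divisible by the number of devices

# One raw PRNG key per batch element, sharded along the leading axis.
keys = jax.random.split(jax.random.PRNGKey(0), batch)  # shape (batch, 2)
keys = jax.device_put(keys, NamedSharding(mesh, P("batch")))

a = jnp.arange(10)
b = jnp.arange(10, 20)
out_a, out_b = jax.jit(f)(keys, a, b)  # each output has shape (batch,)
```

The design choice here is to keep the per-example function written for a single key and let jax.vmap supply the batch dimension; jit then sees ordinary batched arrays, whatever their sharding.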
Hello everyone,
I recently ran into some issues handling random keys inside functions with sharded inputs. The main challenge is that jax.random.split only accepts a single key, while with sharded inputs the function receives a whole batch of keys. I came up with a workaround that I'd like to share:
If this solution works for you, great!
If you have a better approach, I'd love to hear your suggestions.
A minimal example is:
This throws the following error:
ValueError: split accepts a single key, but was given a key array of shape (8,) != (). Use jax.vmap for batching.
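The original minimal example is not preserved in this thread. A hypothetical reproduction of the same error, assuming the keys come from jax.random.split on a single PRNGKey and the function name f_unbatched is made up for illustration, could look like this. Depending on the JAX version the exception type may be TypeError rather than ValueError, so both are caught.

```python
import jax


def f_unbatched(key):
    # Hypothetical function written for ONE key: split expects shape ().
    k1, k2 = jax.random.split(key)
    return jax.random.normal(k1), jax.random.normal(k2)


keys = jax.random.split(jax.random.PRNGKey(0), 8)  # 8 raw keys, shape (8, 2)

try:
    f_unbatched(keys)  # passes the whole batch of keys to split
    error_message = None
except (TypeError, ValueError) as err:
    error_message = str(err)
```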
The solution (or workaround) that I found is to replace jax.random.split with:
In this way, the above code spits out:
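The replacement code itself is not preserved here. One possible replacement, which is not necessarily the author's, is to map jax.random.split over the batch of keys with jax.vmap, as the error message suggests. This sketch assumes raw PRNGKey keys with the default threefry implementation.

```python
import jax

keys = jax.random.split(jax.random.PRNGKey(0), 8)  # 8 raw keys, shape (8, 2)

# Map split over the leading (batch) axis: each key yields 2 subkeys.
subkeys = jax.vmap(jax.random.split)(keys)  # shape (8, 2, 2)
k1, k2 = subkeys[:, 0], subkeys[:, 1]  # each shape (8, 2)

# Samplers can be vmapped over the batched subkeys in the same way.
samples = jax.vmap(jax.random.normal)(k1)  # shape (8,)
```

Compared with wrapping the whole function in jax.vmap, this keeps the surrounding code batched and only vectorizes the key operations.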