Conversions between JAX PRNGKey, integer seed and NumPy RandomState/Generator #8446
ethanluoyc asked this question in Q&A
-
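I don't know of any definitive answer regarding this. For NumPy in particular, the new `Generator` API (`np.random.default_rng`) accepts an array of integers as a seed, so you can seed it directly from a key:

```python
from jax import random
import numpy as np

key = random.PRNGKey(0)
# A raw PRNGKey is a uint32 array, which default_rng accepts as seed entropy.
rng = np.random.default_rng(np.asarray(key))
```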
-
Hi,
Is there a recommended approach for converting back and forth between JAX PRNGKeys and other random number generators?
Users frequently need random number generators from different libraries (NumPy, or other libraries that expect an integer seed). Imagine I have a function that takes a JAX PRNGKey. Inside the function, I need to split the key and use the resulting subkeys to seed other libraries (via a NumPy RandomState or integer seeds). Since those libraries may not support JAX PRNGKeys, some conversion is needed.
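A minimal sketch of the split-then-seed pattern described above, assuming raw `uint32` keys from `jax.random.PRNGKey` (the default) and the `default_rng` seeding suggested in the reply. `numpy_generators_from_key` is a hypothetical helper name, not a JAX or NumPy API.

```python
import jax
import numpy as np


def numpy_generators_from_key(key, num):
    """Hypothetical helper: split `key` into `num` subkeys, one NumPy Generator each.

    Assumes raw uint32 keys as returned by jax.random.PRNGKey;
    np.random.default_rng accepts an array of non-negative integers as a seed.
    """
    subkeys = jax.random.split(key, num)
    return [np.random.default_rng(np.asarray(k)) for k in subkeys]


key = jax.random.PRNGKey(0)
key, consumer_key = jax.random.split(key)  # keep `key` for further JAX use
rngs = numpy_generators_from_key(consumer_key, 2)
draws = [rng.integers(0, 100, size=5) for rng in rngs]
```

Because each Generator is seeded from its own subkey, re-running with the same `consumer_key` reproduces the same NumPy streams.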
For example, here is my approach to converting PRNGKeys to integer seeds for seeding an OpenAI Gym environment. I wrote:
as opposed to:
I chose the former because I have heard that the latter approach may suffer from correlations between the seeds, but I am not sure whether this applies to JAX.
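One possible way to derive integer seeds (a sketch, not necessarily the approach referred to above) is to draw them from a dedicated subkey with `jax.random.randint`, so they are pseudo-random values rather than consecutive integers like `base, base + 1, ...`. `int_seeds_from_key` is a hypothetical helper, and the upper bound of `2**31 - 1` is an assumption chosen to stay within JAX's default 32-bit integers.

```python
import jax


def int_seeds_from_key(key, num):
    """Hypothetical helper: draw `num` Python-int seeds from a JAX PRNGKey.

    Seeds come from jax.random.randint, so they are pseudo-random values in
    [0, 2**31 - 1) rather than a consecutive range.
    """
    seeds = jax.random.randint(key, shape=(num,), minval=0, maxval=2**31 - 1)
    return [int(s) for s in seeds]


key = jax.random.PRNGKey(0)
key, seed_key = jax.random.split(key)
env_seeds = int_seeds_from_key(seed_key, 4)  # e.g. one seed per environment
```

Each seed is a plain Python `int`, so it can be passed to whatever seeding call the target library provides.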
Some libraries expect a NumPy RandomState/Generator. In those cases I suppose it is safe to simply use the PRNGKey to seed the RandomState, so it should be OK to just do:
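A minimal sketch of seeding a legacy `RandomState` from a subkey, assuming a raw `uint32` PRNGKey: `RandomState` accepts a 1-d array of values convertible to 32-bit unsigned integers, which such a key already is.

```python
import jax
import numpy as np

key = jax.random.PRNGKey(42)
key, np_key = jax.random.split(key)

# The raw key is a 1-d uint32 array, a valid seed for the legacy RandomState.
rs = np.random.RandomState(np.asarray(np_key))
x = rs.uniform(size=3)

# Re-seeding from the same subkey reproduces the same draws.
rs_again = np.random.RandomState(np.asarray(np_key))
assert np.array_equal(x, rs_again.uniform(size=3))
```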
I would love to know whether the approach above is sound, or whether there is an even safer way to do these conversions. Thanks!