Overlap in PRNG in Jax #15556
I'm working on a simulation in which I need to initialize the same neural network multiple times, and I'm using [...]

Now, I need the initializations to be independent, but I'm not sure that this is what I'm getting. Specifically, independence fails if the sequence of numbers (in my case, the neural network parameters) generated with the [...]

My question is: is the absence of overlaps guaranteed? If so, how many parameters can I generate before overlaps occur? I have experience with an implementation of the Lehmer random number generator in which you can generate a certain number of seeds from a master seed, with a guaranteed distance (in terms of the number of calls to the PRNG) between the seeds. Can this behavior be obtained in JAX?
Replies: 1 comment
Hi - thanks for the question!

If you're concerned with random draws being independent, then random.split is exactly what you want. By construction, this ensures that if you do something like this:

    key1, key = jax.random.split(key)
    r1 = jax.random.uniform(key1, shape)
    key2, key = jax.random.split(key)
    r2 = jax.random.uniform(key2, shape)

then r1 and r2 will be independent. You can read more about this in JAX's PRNG design doc: https://jax.readthedocs.io/en/latest/jep/263-prng.html

Regarding "overlaps", it looks like section 2.2.1 of this paper is relevant: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf

It suggests that overlap in streams from JAX's PRNG (which uses threefry by default) should be negligible.
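For the original use case (initializing the same network several times), the splitting pattern could look roughly like the sketch below. The `init_params` helper, the layer sizes, and the number of runs are hypothetical and only for illustration; they are not taken from the question. The idea is to derive one key per run from a single master key, and then split again inside the initializer so that each layer gets its own key.

```python
import jax
import jax.numpy as jnp


def init_params(key, sizes=(4, 8, 2)):
    """Hypothetical MLP initializer; layer sizes are illustrative only."""
    # One subkey per layer, so no key is ever used for two draws.
    layer_keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(layer_keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        b = jnp.zeros((n_out,))
        params.append((w, b))
    return params


master_key = jax.random.PRNGKey(0)  # single master seed
n_runs = 5  # assumed number of independent initializations

# One key per run, all derived from the master key.
run_keys = jax.random.split(master_key, n_runs)
all_params = [init_params(k) for k in run_keys]

# Alternative: derive the key for run i from an integer index.
key_for_run_3 = jax.random.fold_in(master_key, 3)
params_run_3 = init_params(key_for_run_3)
```

The rule to follow is that a key is consumed at most once: either it is passed to a sampling function or it is split further, never both. Note that this is a different mechanism from spacing seeds along a single sequential stream as with a Lehmer generator, so there is no "distance" between keys to configure.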