
The following should work. I don't know whether it's the optimal way to do this with respect to the number of host copies. cc @jekbradbury

import jax

def f(key):
  # split off a fresh subkey for sampling; the updated key is not needed here
  _, subkey = jax.random.split(key)
  return jax.random.uniform(subkey, shape=(3,))

prng_key = jax.random.PRNGKey(41)

# each device gets a different prng_key
device_keys = jax.random.split(prng_key, num=jax.local_device_count())
result = jax.pmap(f)(device_keys)
# note that each device produces a different (3,) array of uniform random values
print(repr(result))

# ShardedDeviceArray([[0.0898279 , 0.21871209, 0.975634  ],
#                    [0.92832124, 0.5143677 , 0.32663155],
#                    [0.5004319 , 0.42006207, 0.5135381 ],
#        …
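If you'd rather not split keys on the host, one alternative is to hand every device the same key and fold in each device's index along the pmapped axis with `jax.random.fold_in`. This is a sketch, not necessarily what the answer above intended: the axis name `"devices"` is an arbitrary choice, and the replicated key is still copied to each device, so it does not by itself reduce host copies.

```python
import jax
import jax.numpy as jnp

def f(key):
    # position of this device along the pmapped axis named "devices"
    idx = jax.lax.axis_index("devices")
    # mix the index into the shared key so each device gets its own stream
    device_key = jax.random.fold_in(key, idx)
    return jax.random.uniform(device_key, shape=(3,))

n = jax.local_device_count()
prng_key = jax.random.PRNGKey(41)
# pmap needs a leading axis of size n, so replicate the single key n times
keys = jnp.broadcast_to(prng_key, (n,) + prng_key.shape)
result = jax.pmap(f, axis_name="devices")(keys)
print(result.shape)  # (local_device_count, 3)
```

This also runs on a single-device machine (n is 1). The random streams differ from the split-based version, so the printed values won't match it.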

Replies: 3 comments

Answer selected by jaeyoo
Category
Q&A
Labels
None yet
2 participants