
You can do this by generating standard normal samples with jax.random.normal, multiplying them by sigma, and adding mu:

```python
from jax import random

mu = 1.0
sigma = 3.0
key = random.PRNGKey(1701)
x = mu + sigma * random.normal(key, shape=(10000,))

print(x.mean(), x.std())
# 0.9974507 2.9847035
```
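The same shift-and-scale trick extends to per-element means and standard deviations, similar to passing tensors as both arguments of torch.normal. Below is a minimal sketch, assuming a hypothetical helper named sample_normal (it is not part of JAX) and relying on ordinary NumPy-style broadcasting between mu, sigma and the requested shape:

```python
import jax.numpy as jnp
from jax import random


def sample_normal(key, mu, sigma, shape=None):
    """Draw samples from N(mu, sigma**2); sigma is a standard deviation.

    Hypothetical helper (not a JAX API): mu and sigma may be scalars or
    arrays and are broadcast against each other and against `shape`.
    """
    mu = jnp.asarray(mu, dtype=jnp.float32)
    sigma = jnp.asarray(sigma, dtype=jnp.float32)
    if shape is None:
        # One sample per broadcast (mu, sigma) element.
        shape = jnp.broadcast_shapes(mu.shape, sigma.shape)
    z = random.normal(key, shape=shape, dtype=jnp.float32)  # standard normal draws
    return mu + sigma * z  # shift by the mean, scale by the standard deviation


key = random.PRNGKey(0)
mu = jnp.array([0.0, 10.0, -5.0])
sigma = jnp.array([1.0, 0.5, 2.0])

# 100000 draws for each of the three (mu, sigma) pairs.
samples = sample_normal(key, mu, sigma, shape=(100000, 3))
print(samples.shape)  # (100000, 3)
print(samples.mean(axis=0), samples.std(axis=0))
```

Passing shape draws many samples per (mu, sigma) pair along the leading axis; omitting it draws a single sample for each broadcast element.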

(Note that sigma in this example is the standard deviation rather than the variance; this matches the convention of torch.normal, whose std argument is also a standard deviation.)
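If you start from a variance instead, take its square root before scaling. A short sketch (the variable name var is illustrative) that also shows what happens if the variance is mistakenly used as the scale factor:

```python
import jax.numpy as jnp
from jax import random

mu = 1.0
var = 9.0              # variance, i.e. sigma**2
sigma = jnp.sqrt(var)  # convert the variance to a standard deviation (3.0)

key = random.PRNGKey(1701)
z = random.normal(key, shape=(10000,))  # standard normal draws

x = mu + sigma * z      # correct: scale by the standard deviation
x_wrong = mu + var * z  # mistake: scaling by the variance squares the spread

print(x.var())        # close to 9
print(x_wrong.var())  # close to 81 (= var**2)
```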

Answer selected by Dsantra92