How do I sample from a normal distribution with a given mean and standard deviation in JAX? To be more explicit, I am looking for a function corresponding to the torch function …
Answered by jakevdp on Apr 4, 2021
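You can do this by generating standard normal numbers with `jax.random.normal`, then multiplying by `sigma` and adding `mu`:

```python
from jax import random

mu = 1.0     # desired mean
sigma = 3.0  # desired standard deviation (not the variance)
key = random.PRNGKey(1701)

# Draw standard normal samples N(0, 1), then scale and shift them to N(mu, sigma**2).
x = mu + sigma * random.normal(key, shape=(10000,))
print(x.mean(), x.std())
# 0.9974507 2.9847035
```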
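Why the shift-and-scale works, as a short derivation from standard properties of the normal distribution (an affine function of a Gaussian random variable is again Gaussian):

$$
\begin{aligned}
Z &\sim \mathcal{N}(0, 1), \qquad X = \mu + \sigma Z, \\
\mathbb{E}[X] &= \mu + \sigma\,\mathbb{E}[Z] = \mu, \\
\operatorname{Var}[X] &= \sigma^2 \operatorname{Var}[Z] = \sigma^2, \\
\Rightarrow\; X &\sim \mathcal{N}(\mu, \sigma^2).
\end{aligned}
$$

So if you are given a variance $v$ rather than a standard deviation, use $\sigma = \sqrt{v}$ as the scale factor.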
Note that `sigma` here represents the standard deviation rather than the variance; I believe this is the same convention used in `torch.normal`.

Answer selected by Dsantra92.
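`torch.normal` also accepts tensors of means and standard deviations and samples each element from its own normal distribution. The sketch below shows one way to get similar elementwise behaviour in JAX by combining the shift-and-scale approach from the answer with broadcasting. `normal_with_params` is a hypothetical helper name (not a JAX API), and its `shape` handling is an assumption for illustration, not a copy of PyTorch's exact semantics.

```python
import jax.numpy as jnp
from jax import random


def normal_with_params(key, mean, std, shape=None):
    """Sample N(mean, std**2) elementwise (hypothetical helper, not part of JAX).

    `mean` and `std` may be scalars or arrays. `std` is the standard deviation,
    not the variance. If `shape` is None, the output shape is the broadcast of
    mean.shape and std.shape; otherwise `shape` must broadcast with both.
    """
    mean = jnp.asarray(mean, dtype=jnp.float32)
    std = jnp.asarray(std, dtype=jnp.float32)
    if shape is None:
        shape = jnp.broadcast_shapes(mean.shape, std.shape)
    # Standard normal draws, then scale by std and shift by mean (broadcast elementwise).
    z = random.normal(key, shape=shape)
    return mean + std * z


# Usage: 20000 draws for each of three (mean, std) pairs.
key = random.PRNGKey(0)
means = jnp.array([0.0, 10.0, -5.0])
stds = jnp.array([1.0, 0.5, 2.0])
samples = normal_with_params(key, means, stds, shape=(20000, 3))
print(samples.mean(axis=0))  # should be close to [0, 10, -5]
print(samples.std(axis=0))   # should be close to [1, 0.5, 2]
```

Passing `std` as the scale (rather than a variance) keeps the helper consistent with the convention described in the answer.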