diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index faae2d9b..531200e4 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -274,12 +274,18 @@ for A in matrices: print(A) ``` -One point to remember is that JAX expects tuples to describe array shapes, even for flat arrays. Hence, to get a one-dimensional array of normal random draws we use `(len, )` for the shape, as in +To get a one-dimensional array of normal random draws, we can either use `(len, )` for the shape, as in ```{code-cell} ipython3 random.normal(key, (5, )) ``` +or simply use `5` as the shape argument: + +```{code-cell} ipython3 +random.normal(key, 5) +``` + ## JIT compilation The JAX just-in-time (JIT) compiler accelerates logic within functions by fusing linear