Understanding Performance and Compatibility Issues with jax.numpy Arrays and PRNGs #20491
-
NumPy arrays are silently converted to JAX arrays (and their contents placed on the default device) when passed to a JAX function.
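For instance, a minimal sketch of what that implicit conversion looks like (the exact array type and device printed depend on your JAX version and backend):

```python
import numpy as np
import jax.numpy as jnp

x = np.ones(3)        # a host-side NumPy array
y = jnp.sin(x)        # x is converted to a jax.Array behind the scenes
print(type(y))        # a JAX array type, e.g. jaxlib.xla_extension.ArrayImpl
print(y.devices())    # placed on the default device, e.g. {CpuDevice(id=0)}
```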
It's hard to say without knowing more about your use case. Can you show a minimal example of the code you're running? One potential gotcha here is that under JAX's JIT, Python operations (including numpy ops like `np.random.normal`) are executed once at trace time, so their results are baked into the compiled function as constants:

```python
import jax
import numpy as np

np.random.seed(0)

@jax.jit
def f():
    return np.random.normal()

print(f())  # 1.7640524
print(f())  # 1.7640524
print(f())  # 1.7640524
print(f())  # 1.7640524
print(f())  # 1.7640524
```

Using a compile-time constant will be much faster than generating a new random number at runtime, so this may be why you're seeing such a speedup (but that's only a guess without seeing your code).
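For contrast, a minimal sketch of the usual JAX pattern, where the PRNG key is a runtime argument so each call can produce a fresh value (this assumes a recent JAX release; on older versions use `jax.random.PRNGKey` instead of `jax.random.key`):

```python
import jax

@jax.jit
def f(key):
    # The key is a traced argument, so the draw happens at runtime
    # instead of being baked in as a compile-time constant.
    return jax.random.normal(key)

key = jax.random.key(0)
for _ in range(3):
    key, subkey = jax.random.split(key)
    print(f(subkey))  # a different value on each call
```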
-
Hey!
I hope it's cool if I ask a pretty basic question here. I'm kinda new to jax and hit a snag that's making me scratch my head. Sorry for the long post ahead, but I've got a ton of details and questions all tangled up, so bear with me, please.
Context
I've been diving into using jax for neural quantum states (just a bit of context in case anyone's doing that too). Basically, I'm working with a neural network, like an MLP, and feeding it inputs that I get from Monte Carlo sampling (which matters because we're dealing with PRNGs here). To optimize my network, I use jax.grad for differentiation. But here's the thing: at some point, I need to get the gradient (and Laplacian) of the MLP with respect to its inputs. I'm also using jit compilation and vmap for some of my functions.
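Roughly, that part of my setup looks like this (a stripped-down sketch with a toy `net` standing in for my actual MLP):

```python
import jax
import jax.numpy as jnp

# Toy scalar-valued stand-in for the MLP.
def net(params, x):
    return jnp.tanh(x @ params["w"] + params["b"]).sum()

# Gradient with respect to the inputs, vectorized over a batch of samples.
grad_wrt_inputs = jax.jit(jax.vmap(jax.grad(net, argnums=1), in_axes=(None, 0)))

params = {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)}
batch = jnp.ones((8, 3))
print(grad_wrt_inputs(params, batch).shape)  # (8, 3)
```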
Not too far in, I realized my sampling was using numpy for random numbers, which meant my network's inputs were np arrays. And my network's parameters were np arrays too. But then I learned that for jit and autograd magic in jax to work, I gotta use jax.numpy arrays instead. So, I switched my inputs and parameters to jax.numpy arrays.
Problem (?)
I started by switching the parameters to jnp arrays, and surprisingly, it didn't really affect my code's speed or results...
Next, I had to take care of my inputs. Since jax's PRNGs are slower than numpy's fancy PCG (as far as I know), I first tried simply converting my numpy random array to a jax.numpy array with `jnp.array(rng.normal(loc=state.positions, scale=self.scale))`. My code follows all the immutable-array rules and doesn't throw errors, but this move slowed my code down by about 4x. So I switched to using jax.random properly, but it's still way slower than the numpy approach, again around 4x.
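Here's a stripped-down sketch of the two variants I tried (with made-up placeholders for `state.positions` and `self.scale`; `jax.random.key` may be `jax.random.PRNGKey` on older releases):

```python
import numpy as np
import jax
import jax.numpy as jnp

positions = np.zeros((1000, 3))  # placeholder for state.positions
scale = 0.1                      # placeholder for self.scale

# Variant 1: sample on the host with NumPy, then wrap in a JAX array
# (one host->device transfer per sampling step).
rng = np.random.default_rng(0)
proposal_np = jnp.array(rng.normal(loc=positions, scale=scale))

# Variant 2: sample directly with jax.random, so the draw stays on device.
key = jax.random.key(0)
key, subkey = jax.random.split(key)
proposal_jax = jnp.asarray(positions) + scale * jax.random.normal(subkey, positions.shape)
```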
This leaves me with a couple of mysteries:
Thanks a ton for sticking with my ramble!