Hi. I'm trying to write an nn.Module whose parameters are a sparse COO matrix.
When running my code I get an IndexError. Upon examination, the parameter sparse_weights is <class 'jax.interpreters.ad.JVPTracer'> instead of a DeviceArray. Can someone help me with this? Many thanks!
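A minimal sketch of the situation in the question. It assumes the sparse weights are stored as a COO pair: fixed integer row/column indices plus a trainable array of nonzero values. Plain JAX functions stand in for nn.Module, and ROWS, COLS, coo_matvec and loss are hypothetical names. The sparse matrix-vector product is computed with jax.ops.segment_sum. Recording the type of values inside the loss shows that jax.value_and_grad passes in a tracer rather than a concrete array. The exact tracer class name can vary between JAX versions.

```python
import jax
import jax.numpy as jnp

# Hypothetical 3x4 sparse matrix in COO form with 3 nonzeros (assumed layout).
ROWS = jnp.array([0, 1, 2])  # row index of each nonzero
COLS = jnp.array([1, 3, 0])  # column index of each nonzero
N_ROWS = 3

seen_types = []  # records the type of `values` as seen inside the loss


def coo_matvec(values, x):
    # y[r] = sum of values[k] * x[COLS[k]] over all nonzeros k with ROWS[k] == r
    return jax.ops.segment_sum(values * x[COLS], ROWS, num_segments=N_ROWS)


def loss(values, x):
    seen_types.append(type(values).__name__)
    return jnp.sum(coo_matvec(values, x) ** 2)


values = jnp.array([1.0, 2.0, 3.0])  # the trainable nonzero entries
x = jnp.ones(4)

# value_and_grad differentiates w.r.t. the first argument (values) by default.
val, grad = jax.value_and_grad(loss)(values, x)
print(seen_types)  # a tracer class name, not a concrete array type
print(val, grad)
```

Only values is differentiated here. The integer indices live outside the differentiated arguments, so no gradient is requested with respect to them.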
Regarding the variable being a JVPTracer when you print it within the value_and_grad transform: this is expected behavior. JAX uses tracers as stand-ins for DeviceArray objects when transforming functions with jit, vmap, grad, and other transforms. Take a look at How to think in JAX for some background on this.

Regarding the IndexError: it sounds like you're attempting to index a scalar value within the vmap expression, but it's difficult to tell why because your code snippet doesn't show how you're calling the functions. Here's a simpler example of how that can happen:

import jax
import jax.numpy as jnp

@jax.vmap
def f(x):
    # vmap maps f over the leading axis of its input, so x is 0-dimensional here
    print(f"type(x) = {type(x)}")
    print(f"x.shape = {x.shape}")
    return x[0]  # raises IndexError: a scalar cannot be indexed

x = jnp.arange(4)
f(x)

Here, vmap maps f over the leading axis of x, so inside f each x is a scalar tracer with shape (), and x[0] raises an IndexError. Your situation is probably analogous: the solution would be either to pass a higher-dimensional array to your vmapped function, or perhaps to remove the …

Hope that helps! If you need more specific recommendations for fixing your problem, I'd suggest editing your question to add a minimal reproducible example that others can run to see the exact behavior you're seeing. Best of luck!
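A short sketch of the first suggested fix, applied to the vmap example from the reply: give the vmapped function an input with one more dimension, so each mapped x is a 1-D row that can be indexed. The second function shows an un-batched alternative, assuming no batching was actually intended. The names first_elem and first_elem_unbatched are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.vmap
def first_elem(x):
    # vmap maps over the leading axis, so each x has one fewer dimension
    return x[0]


# Fix 1: pass a 2-D input, so each mapped x is a 1-D row of shape (2,)
batched = jnp.arange(8).reshape(4, 2)
print(first_elem(batched))  # [0 2 4 6]


# Fix 2 (assumes no batching is intended): index the 1-D array directly
def first_elem_unbatched(x):
    return x[0]


print(first_elem_unbatched(jnp.arange(4)))  # 0
```

Calling first_elem(jnp.arange(4)) instead reproduces the IndexError, because each mapped x is then 0-dimensional.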