Jax vmap - ValueError: vmap must have at least one non-None value in in_axes #21380
-
I have the above code defined in class A, and I have a class B that inherits from A, i.e., class B(A). The functions forward and v_forward are not overridden in B. When I call self.v_forward(arg1, arg2), I get the error in the title: ValueError: vmap must have at least one non-None value in in_axes. Does anyone know why this happens, since I do have a 0 in my in_axes tuple?
-
Hi, I tried to reproduce your issue based on the description, but I can't. Here's the code I wrote:

```python
import jax

class A:
    def f(self, x, y):
        return x + y

    # self and x are broadcast (None); only y is mapped along axis 0.
    vf = jax.vmap(f, in_axes=(None, None, 0))

class B(A):
    pass  # inherits f and vf unchanged

A().vf(1, jax.numpy.arange(4))  # No error
B().vf(1, jax.numpy.arange(4))  # No error
```

Can you please update your question with a minimal reproducible example that recreates the error you're seeing? Thanks!
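Since the original code isn't shown, here is one hedged guess at how this error can show up even with a 0 in in_axes. vmap flattens its arguments into pytree leaves before it looks at in_axes, so if the argument in the mapped position contains no arrays (for example, it is None or an empty tuple or dict), no leaf is left with a non-None axis, and, at least in recent JAX versions, vmap raises this same ValueError. The class and method names below mirror the question, but the body of forward is hypothetical and made up for illustration.

```python
import jax
import jax.numpy as jnp


class A:
    # Hypothetical stand-in for the question's `forward`; the real body is unknown.
    def forward(self, arg1, arg2):
        return arg1 * arg2

    # self and arg1 are broadcast (None); only arg2 is mapped along axis 0.
    v_forward = jax.vmap(forward, in_axes=(None, None, 0))


class B(A):
    pass  # inherits forward and v_forward unchanged


b = B()

# Works: arg2 is an array, so the 0 in in_axes maps over its first axis.
print(b.v_forward(2.0, jnp.arange(3.0)))  # [0. 2. 4.]

# Fails: None is an empty pytree, so the 0 in in_axes applies to no leaves.
# The only remaining leaves (self and arg1) both have in_axes=None.
try:
    b.v_forward(2.0, None)
except ValueError as err:
    print(err)
```

If this matches your case, it may be worth checking what is actually being passed as arg2 at the failing call site before it reaches v_forward.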