Hi, I tried to reproduce your issue based on the description, but I couldn't. Here's the code I wrote:

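```python
import jax

class A:
  def f(self, x, y):
    return x + y
  # Map over y only; self and x are passed through unbatched (None).
  # vf is a plain function stored on the class, so it binds like a method
  # and the instance becomes its first argument.
  vf = jax.vmap(f, in_axes=(None, None, 0))

class B(A):
  pass

A().vf(1, jax.numpy.arange(4))  # No error
B().vf(1, jax.numpy.arange(4))  # No error
```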
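One thing that might matter if your real subclass overrides `f`: because `vf` is created in the class body, it wraps the function object `A.f` at class-definition time, so an override in `B` is not picked up by `vf`. The sketch below contrasts that with building the vmapped function from the bound method on each call. The names `vf_static` and `vf_dynamic`, and the multiplying `B.f`, are hypothetical and only for illustration; this is not necessarily what your code does.

```python
import jax
import jax.numpy as jnp

class A:
  def f(self, x, y):
    return x + y

  # Same pattern as above: vmap wraps A.f once, when the class is created.
  vf_static = jax.vmap(f, in_axes=(None, None, 0))

  # Per-call version: vmap wraps the bound method self.f, so a subclass
  # override of f is used. Only x (None) and ys (0) are passed to vmap here.
  def vf_dynamic(self, x, ys):
    return jax.vmap(self.f, in_axes=(None, 0))(x, ys)

class B(A):
  # Hypothetical override, only to illustrate the difference.
  def f(self, x, y):
    return x * y

ys = jnp.arange(4)
print(B().vf_static(2, ys))   # still calls A.f: [2 3 4 5]
print(B().vf_dynamic(2, ys))  # calls B.f: [0 2 4 6]
```

The per-call version re-creates the vmapped function each time, which is cheap in eager mode; wrapping it in `jax.jit` is an option if call overhead matters.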

Can you please update your question with a minimal reproducible example that recreates the error you're seeing? Thanks!

Answer selected by GarfieldGa