Thanks for the question; this is very curious, and I don't know why it's happening. For reference, here's a minimal repro that passes with the old `jit` (the default in jax 0.4.3 and older) and fails with the new `pjit` (the default in jax 0.4.4):

```python
import abc
import jax

class Base(abc.ABC):
  @abc.abstractstaticmethod
  def _abs(x): ...

class Derived(Base):
  _abs = staticmethod(jax.numpy.abs)

d = Derived()  # works with jax 0.4.3 and older; raises TypeError with jax 0.4.4
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-20-b8456e68a3b7> in <module>
      9   _abs = staticmethod(jax.numpy.abs)
     10 
---> 11 d …
```
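The traceback is cut off, but a `TypeError` at `d = Derived()` is what `abc` raises when a class still has names in its `__abstractmethods__` set. `ABCMeta` builds that set by checking each class attribute for a truthy `__isabstractmethod__`, and `staticmethod` forwards that check to the object it wraps, so it is the wrapped `jax.numpy.abs` object itself that gets inspected. The pure-Python sketch below reproduces that mechanism without jax. `make_callable`, `Broken` and `Fixed` are hypothetical names, and setting `__isabstractmethod__` by hand only simulates a callable that reports itself as abstract; this is an assumption for illustration, not a description of how the pjit wrapper actually behaves.

```python
import abc


class Base(abc.ABC):
    # Idiomatic spelling of abc.abstractstaticmethod (deprecated since Python 3.3).
    @staticmethod
    @abc.abstractmethod
    def _abs(x): ...


def make_callable(report_abstract):
    """Hypothetical helper: return a plain function, optionally tagged as abstract."""
    def fn(x):
        return abs(x)
    if report_abstract:
        # Simulates a wrapped callable whose attributes claim it is abstract.
        # ABCMeta only looks at this one attribute when deciding abstractness.
        fn.__isabstractmethod__ = True
    return fn


class Broken(Base):
    # staticmethod forwards __isabstractmethod__ from fn, so '_abs' stays abstract.
    _abs = staticmethod(make_callable(report_abstract=True))


class Fixed(Base):
    # A plain function has no __isabstractmethod__, so '_abs' counts as implemented.
    _abs = staticmethod(make_callable(report_abstract=False))


print(Broken.__abstractmethods__)  # frozenset({'_abs'})
print(Fixed.__abstractmethods__)   # frozenset()

try:
    Broken()
except TypeError as err:
    # Same kind of failure as `d = Derived()` in the repro.
    print("TypeError:", err)

print(Fixed()._abs(-3))  # 3
```

If this is the mechanism at play, one possible workaround for the repro is to hand `ABCMeta` a plain function to inspect instead of the jitted object, e.g. `@staticmethod` over `def _abs(x): return jax.numpy.abs(x)`, which is what `Fixed` does with its untagged callable.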

Replies: 1 comment, 3 replies

@jakevdp
@jakevdp
@rmoyard
Answer selected by rmoyard
Category: Q&A
Labels: None yet
2 participants