Hi JAX team! I am encountering an issue with the new version, 0.4.4. I have a class called …
Since 0.4.4, the device cannot be instantiated because of three functions, `abs`, `conj` and `dot`. The error message is the following:
This makes me think that, with the recent update, those three functions became abstract. I can tell that the base class changed if you run:
Before 0.4.4, you get:
After 0.4.4, you get:
What is the impact of the switch from …
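A minimal sketch of one way to narrow this down in plain Python: ask the class itself which names `ABCMeta` still considers abstract, what each name resolves to, and whether that object reports a truthy `__isabstractmethod__`. The helper `explain_abstract` and the classes `BaseDevice`/`PartialDevice` are hypothetical stand-ins, not the classes from the question. Pointed at the real class under 0.4.4, the report would show what kind of object `abs`, `conj` and `dot` resolve to and whether that object claims to be abstract.

```python
import abc


def explain_abstract(cls):
    """For each name ABCMeta still treats as abstract on `cls`, report the
    type of the object `cls` resolves it to and whether that object's
    __isabstractmethod__ attribute is truthy."""
    report = {}
    for name in sorted(cls.__abstractmethods__):
        value = getattr(cls, name, None)
        report[name] = (
            type(value).__name__,
            bool(getattr(value, "__isabstractmethod__", False)),
        )
    return report


class BaseDevice(abc.ABC):  # hypothetical stand-in for the question's base class
    @staticmethod
    @abc.abstractmethod
    def conj(x): ...

    @staticmethod
    @abc.abstractmethod
    def dot(x, y): ...


class PartialDevice(BaseDevice):  # hypothetical: implements conj, forgets dot
    conj = staticmethod(lambda x: x.conjugate())


print(explain_abstract(BaseDevice))     # {'conj': ('function', True), 'dot': ('function', True)}
print(explain_abstract(PartialDevice))  # {'dot': ('function', True)}
```

A name that shows up in the report even though the subclass assigned an implementation to it points at the assigned object itself answering `True` for `__isabstractmethod__`.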
Thanks for the question. This is very curious, and I do not know why it's happening. For reference, here's a minimal repro that passes with the old `jit` (the default in JAX 0.4.3 and older) and fails with the new `pjit` (the default in JAX 0.4.4):

```python
import abc
import jax

class Base(abc.ABC):
    @abc.abstractstaticmethod
    def _abs(x): ...

class Derived(Base):
    _abs = staticmethod(jax.numpy.abs)

d = Derived()
```

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-20-b8456e68a3b7> in <module>
      9     _abs = staticmethod(jax.numpy.abs)
     10 
---> 11 d = Derived()

TypeError: Can't instantiate abstract class Derived with abstract methods _abs
```

There must be some detail about how …
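One mechanism that would produce exactly this error, offered as a hypothesis rather than a confirmed diagnosis: `ABCMeta` treats a class attribute as abstract when `getattr(value, "__isabstractmethod__", False)` is truthy, and `staticmethod` forwards that lookup to the callable it wraps. So if the object behind `jax.numpy.abs` answered that lookup with something truthy, `staticmethod(jax.numpy.abs)` would look abstract. The sketch below uses a hypothetical `PermissiveWrapper` in place of the jitted function; whether `pjit`'s function object actually behaves this way is an assumption. It also shows a workaround: a plain `def` that calls the wrapper, which keeps the wrapper out of the ABC check.

```python
import abc


class PermissiveWrapper:
    """Hypothetical callable wrapper that reports a truthy __isabstractmethod__.

    Stand-in for a proxy that manufactures attributes on demand; only the one
    name that matters here is faked, to keep the demo predictable. Whether
    jax's pjit wrapper does anything like this is an assumption.
    """

    def __init__(self, fun):
        self._fun = fun

    def __call__(self, *args, **kwargs):
        return self._fun(*args, **kwargs)

    def __getattr__(self, name):
        # Runs only when normal attribute lookup fails.
        if name == "__isabstractmethod__":
            return object()  # truthy placeholder: this is what trips ABCMeta
        raise AttributeError(name)


class Base(abc.ABC):
    # Equivalent to the repro's @abc.abstractstaticmethod (deprecated spelling).
    @staticmethod
    @abc.abstractmethod
    def _abs(x): ...


def plain_abs(x):
    return x if x >= 0 else -x


class Broken(Base):
    # staticmethod forwards __isabstractmethod__ to the wrapper, which says yes.
    _abs = staticmethod(PermissiveWrapper(plain_abs))


class Fixed(Base):
    # Workaround: a plain function has no __isabstractmethod__ attribute,
    # so the wrapper is never consulted by the ABC machinery.
    @staticmethod
    def _abs(x):
        return PermissiveWrapper(plain_abs)(x)


print(sorted(Broken.__abstractmethods__))  # ['_abs']
print(sorted(Fixed.__abstractmethods__))   # []

try:
    Broken()
    broken_error = None
except TypeError as exc:  # "Can't instantiate abstract class Broken ..."
    broken_error = exc
print("Broken() raised:", broken_error)
print(Fixed._abs(-2))  # 2
```

Applied to the repro, the same workaround would be `@staticmethod def _abs(x): return jax.numpy.abs(x)`. A plain function has no `__isabstractmethod__`, so the check passes regardless of what the jitted object reports.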