Skip to content

Commit 1d07af1

Browse files
authored
Change jax DeviceArray to ndarray (#591)
* Change DeviceArray to jnp.ndarray * Pump minimum jax version
1 parent 8d69fd3 commit 1d07af1

File tree

3 files changed

+2
-13
lines changed

3 files changed

+2
-13
lines changed

funsor/interpretations.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,6 @@ def interpret(self, cls, *args):
5353
@staticmethod
5454
def make_hash_key(cls, *args):
5555
backend = get_backend()
56-
if backend == "jax":
57-
# JAX DeviceArray has .__hash__ method but raise the unhashable error there.
58-
from jax.interpreters.xla import DeviceArray
59-
60-
return tuple(
61-
id(arg)
62-
if isinstance(arg, DeviceArray) or not isinstance(arg, Hashable)
63-
else arg
64-
for arg in args
65-
)
6656
if backend == "torch":
6757
# Avoid "ImportError: sys.meta_path is None" on shutdown.
6858
from torch import Tensor

funsor/jax/ops.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import numpy as onp
1010
from jax import lax
1111
from jax.core import Tracer
12-
from jax.interpreters.xla import DeviceArray
1312
from jax.scipy.linalg import cho_solve, solve_triangular
1413
from jax.scipy.special import expit, gammaln, logsumexp
1514

@@ -19,7 +18,7 @@
1918
# Register Ops
2019
################################################################################
2120

22-
array = (onp.generic, onp.ndarray, DeviceArray, Tracer)
21+
array = (onp.generic, onp.ndarray, np.ndarray, Tracer)
2322
ops.atanh.register(array)(np.arctanh)
2423
ops.clamp.register(array)(np.clip)
2524
ops.exp.register(array)(np.exp)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
],
4040
extras_require={
4141
"torch": ["pyro-ppl>=1.8.0", "torch>=1.11.0"],
42-
"jax": ["numpyro>=0.7.0", "jax>=0.2.13", "jaxlib>=0.1.65"],
42+
"jax": ["numpyro>=0.7.0", "jax>=0.2.21", "jaxlib>=0.1.71"],
4343
"test": [
4444
"black",
4545
"flake8",

0 commit comments

Comments
 (0)