Skip to content

Commit 97bbc37

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[dlpack] Support more DLPack dtypes now that we target DLPack 1.1
I did not update `jax.dlpack.SUPPORTED_DTYPES` because neither NumPy nor TensorFlow currently support importing DLPack arrays with the new dtypes. PiperOrigin-RevId: 736882904
1 parent c9ac82c commit 97bbc37

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

tests/array_interoperability_test.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from jax.sharding import PartitionSpec as P
2323
from jax._src import config
2424
from jax._src import test_util as jtu
25+
from jax._src.lib import version as jaxlib_version
2526

2627
import numpy as np
2728

@@ -42,6 +43,27 @@
4243

4344
dlpack_dtypes = sorted(jax.dlpack.SUPPORTED_DTYPES, key=lambda x: x.__name__)
4445

46+
# These dtypes are not supported by neither NumPy nor TensorFlow, therefore
47+
# we list them separately from ``jax.dlpack.SUPPORTED_DTYPES``.
48+
extra_dlpack_dtypes = []
49+
if jaxlib_version >= (0, 5, 3):
50+
extra_dlpack_dtypes = [
51+
jnp.float8_e4m3b11fnuz,
52+
jnp.float8_e4m3fn,
53+
jnp.float8_e4m3fnuz,
54+
jnp.float8_e5m2,
55+
jnp.float8_e5m2fnuz,
56+
] + [
57+
dtype
58+
for name in [
59+
"float4_e2m1fn",
60+
"float8_e3m4",
61+
"float8_e4m3",
62+
"float8_e8m0fnu",
63+
]
64+
if (dtype := getattr(jnp, name, None))
65+
]
66+
4567
numpy_dtypes = sorted(
4668
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
4769
key=lambda x: x.__name__)
@@ -63,14 +85,16 @@ def setUp(self):
6385
self.skipTest(f"DLPack not supported on {jtu.device_under_test()}")
6486

6587
@jtu.sample_product(
66-
shape=all_shapes,
67-
dtype=dlpack_dtypes,
68-
copy=[False, True, None],
69-
use_stream=[False, True],
88+
shape=all_shapes,
89+
dtype=dlpack_dtypes + extra_dlpack_dtypes,
90+
copy=[False, True, None],
91+
use_stream=[False, True],
7092
)
7193
@jtu.run_on_devices("gpu")
72-
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
73-
category=DeprecationWarning)
94+
@jtu.ignore_warning(
95+
message="Calling from_dlpack with a DLPack tensor",
96+
category=DeprecationWarning,
97+
)
7498
def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
7599
rng = jtu.rand_default(self.rng())
76100
np = rng(shape, dtype)

0 commit comments

Comments
 (0)