Skip to content

Commit 478ea0d

Browse files
committed
Allow 64-bit output types from ffi_call regardless of enable_x64 flag.
1 parent 6892e62 commit 478ea0d

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

jax/_src/extend/ffi.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from jax._src import dispatch
2828
from jax._src import effects
2929
from jax._src import util
30-
from jax._src.callback import _check_shape_dtype, callback_batching_rule
30+
from jax._src.callback import callback_batching_rule
3131
from jax._src.interpreters import ad
3232
from jax._src.interpreters import batching
3333
from jax._src.interpreters import mlir
@@ -209,11 +209,14 @@ def _lowering(
209209

210210
def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]:
211211
avals: list[core.AbstractValue] = []
212-
for result in results:
212+
for idx, result in enumerate(results):
213213
if isinstance(result, core.AbstractToken):
214214
avals.append(result)
215215
else:
216-
_check_shape_dtype(result)
216+
if not hasattr(result, "shape") or not hasattr(result, "dtype"):
217+
raise ValueError(
218+
"All elements of result_shape_dtypes must have 'shape' and 'dtype' "
219+
f"attributes. Got {result} at position {idx}.")
217220
avals.append(core.ShapedArray(result.shape, result.dtype))
218221
return tuple(avals)
219222

tests/extend_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from jax._src import abstract_arrays
2929
from jax._src import api
30+
from jax._src import config
3031
from jax._src import core
3132
from jax._src import linear_util
3233
from jax._src import prng
@@ -326,6 +327,21 @@ def fun(x):
326327
"The use of ffi_call attributes requires"):
327328
jax.jit(fun).lower(jnp.ones(5)).as_text()
328329

330+
def testAllow64(self):
331+
if config.enable_x64.value:
332+
self.skipTest("Requires enable_x64=False")
333+
def fun():
334+
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))()
335+
self.assertIn("tensor<i64>", jax.jit(fun).lower().as_text())
336+
337+
def testInvalidResultType(self):
338+
with self.assertRaisesRegex(
339+
ValueError, "All elements of result_shape_dtypes.*position 0"):
340+
jex.ffi.ffi_call("test", None)()
341+
with self.assertRaisesRegex(
342+
ValueError, "All elements of result_shape_dtypes.*position 1"):
343+
jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))()
344+
329345

330346
def ffi_call_geqrf(x, **kwargs):
331347
if jtu.test_device_matches(["cpu"]):

0 commit comments

Comments
 (0)