|
27 | 27 |
|
28 | 28 | from jax._src import abstract_arrays |
29 | 29 | from jax._src import api |
| 30 | +from jax._src import config |
30 | 31 | from jax._src import core |
31 | 32 | from jax._src import linear_util |
32 | 33 | from jax._src import prng |
@@ -326,6 +327,21 @@ def fun(x): |
326 | 327 | "The use of ffi_call attributes requires"): |
327 | 328 | jax.jit(fun).lower(jnp.ones(5)).as_text() |
328 | 329 |
|
| 330 | + def testAllow64(self): |
| 331 | + if config.enable_x64.value: |
| 332 | + self.skipTest("Requires enable_x64=False") |
| 333 | + def fun(): |
| 334 | + return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))() |
| 335 | + self.assertIn("tensor<i64>", jax.jit(fun).lower().as_text()) |
| 336 | + |
| 337 | + def testInvalidResultType(self): |
| 338 | + with self.assertRaisesRegex( |
| 339 | + ValueError, "All elements of result_shape_dtypes.*position 0"): |
| 340 | + jex.ffi.ffi_call("test", None)() |
| 341 | + with self.assertRaisesRegex( |
| 342 | + ValueError, "All elements of result_shape_dtypes.*position 1"): |
| 343 | + jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))() |
| 344 | + |
329 | 345 |
|
330 | 346 | def ffi_call_geqrf(x, **kwargs): |
331 | 347 | if jtu.test_device_matches(["cpu"]): |
|
0 commit comments