Skip to content

Commit 7a6db32

Browse files
authored
fix: use has_side_effect=True in FFI calls (#44)
The [JAX docs](https://docs.jax.dev/en/latest/_autosummary/jax.pure_callback.html) state this: > In the context of JAX transformations, Python exceptions should be considered side-effects: this means that intentionally raising an error within a pure_callback breaks the API contract, and the behavior of the resulting program is undefined. Our callbacks can raise, so switching `has_side_effect` to `True` seems like the safer option. I don't know if this has any runtime impact right now (didn't notice any while testing).
1 parent bc5040d commit 7a6db32

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tesseract_jax/primitive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _dispatch(*args: ArrayLike) -> Any:
229229
array_args,
230230
ctx.avals_in,
231231
ctx.avals_out,
232-
has_side_effect=False,
232+
has_side_effect=True,
233233
)
234234
ctx.module_context.add_keepalive(keepalive)
235235
return result

0 commit comments

Comments
 (0)