diff --git a/tesseract_jax/primitive.py b/tesseract_jax/primitive.py index 1916035..102e69a 100644 --- a/tesseract_jax/primitive.py +++ b/tesseract_jax/primitive.py @@ -229,7 +229,7 @@ def _dispatch(*args: ArrayLike) -> Any: array_args, ctx.avals_in, ctx.avals_out, - has_side_effect=False, + has_side_effect=True, ) ctx.module_context.add_keepalive(keepalive) return result