We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 53bff86 commit bb16d5aCopy full SHA for bb16d5a
jax/_src/ad_checkpoint.py
@@ -318,6 +318,9 @@ def foo(x, y):
318
``jax.ensure_compile_time_eval``), it may be easier to compute some values
319
outside the :func:`jax.checkpoint`-decorated function and then close over them.
320
"""
321
+ if isinstance(static_argnums, int):
322
+ static_argnums = static_argnums,
323
+
324
@wraps(fun)
325
@api_boundary
326
def fun_remat(*args, **kwargs):
0 commit comments