Skip to content

Commit bb16d5a

Browse files
mattjjGoogle-ML-Automation
authored andcommitted
fix bug with jax.remat static_argnums not supporting int
PiperOrigin-RevId: 707600082
1 parent 53bff86 commit bb16d5a

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,9 @@ def foo(x, y):
318318
``jax.ensure_compile_time_eval``), it may be easier to compute some values
319319
outside the :func:`jax.checkpoint`-decorated function and then close over them.
320320
"""
321+
if isinstance(static_argnums, int):
322+
static_argnums = static_argnums,
323+
321324
@wraps(fun)
322325
@api_boundary
323326
def fun_remat(*args, **kwargs):

0 commit comments

Comments
 (0)