static_argnums and JAX control flows #14818
jecampagne asked this question in Q&A
There's no way to specify static variables here, because scanned-over arguments cannot be static by definition (they change from iteration to iteration). The only way to have static variables in the scan body function is if they don't participate in the scan, and you can generally do this via closure, for example:

```python
import jax
from functools import partial

def f(c, x, flag):  # flag must be static
    # flag is bound by partial below, so it is a plain Python bool here
    if flag:
        return c + 1, x
    else:
        return c, x

# xs=None with length=10 runs the body 10 times; flag=True is fixed via closure
jax.lax.scan(partial(f, flag=True), 0, None, 10)
```
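When the flag comes from the caller of a jitted function, the closure trick combines naturally with `jax.jit`'s `static_argnames`. The following is a minimal sketch assuming a recent JAX version; `body`, `run`, and `count` are hypothetical names used only for illustration. The same closure pattern applies to the other control-flow functions, shown here with `jax.lax.fori_loop` in `count`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def body(c, x, flag):
    # `flag` arrives as a plain Python bool (bound by partial), so `if` is fine.
    if flag:
        return c + x, c
    return c - x, c


@partial(jax.jit, static_argnames="flag")
def run(init, xs, flag):
    # `flag` is static for jit; it is captured by the body via partial,
    # so it never becomes part of the scanned carry or xs.
    return jax.lax.scan(partial(body, flag=flag), init, xs)


def count(n, step):
    # Same idea for lax.fori_loop: `step` is a Python int captured by the lambda.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + step, 0)
```

Calling `run(0.0, jnp.arange(4.0), flag=True)` and then `flag=False` compiles two separate versions of `run`, one per static value, which is the expected cost of making an argument static.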
Hi,
When one explicitly uses `jit` as a decorator on a function, one can mark some arguments as static.
But when dealing with `jax.lax.scan(body, init, xs)` (and likewise for the other control-flow functions), the compilation of the `body` function happens without any such control. Is there a way to specify that certain elements of `init` are static?
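To illustrate why an element of `init` cannot be static, here is a small sketch (the names `body` and `try_static_flag_in_init` are hypothetical, and a recent JAX version is assumed) that puts a Python bool into the carry. `lax.scan` traces the body with abstract values for every leaf of the carry, so branching on the flag with a Python `if` raises `jax.errors.ConcretizationTypeError` (in recent versions, its subclass `TracerBoolConversionError`).

```python
import jax
import jax.numpy as jnp


def body(carry, x):
    c, flag = carry
    # Inside scan, `flag` is a traced boolean array, not a Python bool,
    # so this `if` cannot be resolved at trace time.
    if flag:
        c = c + x
    return (c, flag), c


def try_static_flag_in_init():
    try:
        jax.lax.scan(body, (0.0, True), jnp.arange(3.0))
    except jax.errors.ConcretizationTypeError:
        return "flag was traced"
    return "flag was static"
```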