Replies: 1 comment
You could pass the constants by closure instead; modifying your example, it might look something like this:

    import jax.numpy as jnp
    from jax.lax import scan

    def main_func(a_const, b_const, c_const, d_const, e_const, variable, array: jnp.ndarray):
        def compute_variables(variable, x):
            # add a dozen lines of logic here
            return (x + a_const - b_const * variable, variable)

        return scan(compute_variables, init=variable, xs=array)

The general approach is: you can reference |
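To make the closure suggestion concrete, here is a minimal runnable sketch. It trims the constants to two and uses made-up input values; the names `step`, `final` and `history` are illustrative and do not come from the original example.

```python
import jax
import jax.numpy as jnp
from jax import lax


def main_func(a_const, b_const, variable, array):
    # a_const and b_const are captured by the inner function (closure),
    # so the scan carry only holds the value that actually changes.
    def step(carry, x):
        new_carry = x + a_const - b_const * carry
        return new_carry, carry  # (next carry, per-step output)

    return lax.scan(step, variable, array)


# Illustrative inputs only; the carry is given an explicit float32 dtype.
final, history = jax.jit(main_func)(1.0, 0.5, jnp.float32(2.0), jnp.arange(3.0))
print(final, history)
```

With this layout the carry stays a single scalar, and the constants never need to be threaded through each step's return value.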
Sup,

I recently had to write a module where we had 25 `scan`s in a single file... Well, it was not so simple because we had 5 variables in the carry, and most of them were constants. Minimal reproducible example here:

I believe in ML there are also hyperparameters which need to be passed and are needlessly optimized for by `jit` in every line. If there is a constant, XLA may not bother optimizing functions for it, in the same way as `static_argnums` makes functions faster.

Would `scan` maybe get an optional `consts` parameter in which the constant parameters would be passed?

It's a QOL feature which may also speed up `scan` and other loop computations a little bit. Is it big enough to be implemented?

(The same goes for `fori_loop` and `while_loop`s.)
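For the `fori_loop` and `while_loop` cases mentioned here, constants can already be kept out of the carry by closing over them in the body functions. The following is only a sketch under assumed values: `run_loops` and its inner function names are hypothetical, and the arithmetic is a placeholder for real logic.

```python
import jax.numpy as jnp
from jax import lax


def run_loops(a_const, b_const, variable, n_steps):
    # Hypothetical helper: the constants are closed over, not carried.
    def fori_body(i, carry):
        return carry * b_const + a_const

    def while_cond(carry):
        return carry < a_const * 10

    def while_body(carry):
        return carry + b_const

    after_fori = lax.fori_loop(0, n_steps, fori_body, variable)
    after_while = lax.while_loop(while_cond, while_body, variable)
    return after_fori, after_while


# Illustrative inputs only.
after_fori, after_while = run_loops(1.0, 0.5, jnp.float32(0.0), 3)
print(after_fori, after_while)
```

In both loops the carry is just the one changing value, which is roughly what an optional `consts` parameter would provide.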