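This is tricky, as loop constructs are indeed somewhat limited.

Idea 1: since compile times can be superlinear in the size of the program, it might be useful to interpolate between the two extremes: compiling everything, and compiling only primitives (JAX's eager mode is effectively the latter). Concretely, you could jit fixed-size chunks of layers, something like:

```python
from functools import partial
import jax

def chunk_list(funcs, params, size=8):  # hypothetical helper: chunks of `size` layers
    for i in range(0, len(funcs), size):
        yield tuple(funcs[i:i + size]), tuple(params[i:i + size])

def sequential(x, params):
    for fs, ps in chunk_list(funcs, params):
        x = subsequential(fs, ps, x)
    return x

# Layer objects aren't JAX types, so pass them as a static (hashable) argument; params and x are traced.
@partial(jax.jit, static_argnums=0)
def subsequential(fs, ps, x):
    for f, p in zip(fs, ps):
        x = f.forward(x, p)
    return x
```

As a variant, you could also just apply `jit` to each function individually:

```python
import jax

# Wrap each layer once, outside the loop, so the compiled versions are reused.
jitted = [jax.jit(f.forward) for f in funcs]

def sequential(x, params):
    for f, p in zip(jitted, params):
        x = f(x, p)
    return x
```

Idea 2: if the functions happen to have the same input and output shapes, then you could do a …

Idea 3: share a repro and we can ask the XLA compiler engineers whether compilation can be sped up here just by XLA improvements. They want to improve compile times, and having real benchmarks from users is helpful to that end.

Maybe there are other options, but those are the first that spring to mind.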
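For Idea 2, here is a minimal sketch of one way to exploit identical shapes, assuming every layer applies the same `forward` function (a hypothetical `dense_forward` here) with identically shaped parameters. The per-layer parameters are stacked along a leading axis, and `jax.lax.scan` traces the layer body only once, so the compiled program no longer grows with the number of layers:

```python
import jax
import jax.numpy as jnp

def dense_forward(x, p):
    # Hypothetical layer: affine map + tanh, so output shape == input shape.
    w, b = p
    return jnp.tanh(x @ w + b)

@jax.jit
def sequential_scan(x, stacked_params):
    # stacked_params is a pytree whose leaves have a leading "layer" axis;
    # scan feeds one layer's parameters per step and carries x between steps.
    def body(carry, p):
        return dense_forward(carry, p), None
    x, _ = jax.lax.scan(body, x, stacked_params)
    return x

# Example: 50 layers of width 16, parameters stacked along axis 0.
num_layers, width = 50, 16
ws = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (num_layers, width, width))
bs = jnp.zeros((num_layers, width))
x = jnp.ones((4, width))
out = sequential_scan(x, (ws, bs))
```

If the layers use different functions but matching shapes, `jax.lax.switch` on a per-layer index inside the scan body is one option to explore, although every branch gets compiled and all branches must accept and return the same structures.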
I have a function that loops through lists of functions and parameters, passing the output of each function to the next. It looks like this:
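A minimal, runnable sketch of the pattern being described, assuming each entry of `funcs` is an object with a `forward(x, p)` method. The `Dense` class, shapes, and layer count are illustrative stand-ins, not the actual model:

```python
import jax
import jax.numpy as jnp

class Dense:
    """Hypothetical layer type standing in for the model's real layers."""
    def forward(self, x, p):
        w, b = p
        return jnp.tanh(x @ w + b)

funcs = [Dense() for _ in range(3)]
keys = jax.random.split(jax.random.PRNGKey(0), len(funcs))
params = [(0.1 * jax.random.normal(k, (8, 8)), jnp.zeros(8)) for k in keys]

def sequential(x, params):
    # Feed each layer's output into the next layer.
    for f, p in zip(funcs, params):
        x = f.forward(x, p)
    return x

x = jnp.ones((2, 8))
y_eager = sequential(x, params)         # op-by-op execution
y_jit = jax.jit(sequential)(x, params)  # Python loop unrolled into one XLA program
```

Under `jax.jit`, the Python `for` loop is unrolled at trace time, so the compiled program grows with the number of layers, which is why compile time scales with model depth in this pattern.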
I built a pretty large deep learning model using a bunch of these types of functions. When I JIT this function, the whole model takes around 4 minutes to compile and runs in around 40 ms with sample input.
If I don't JIT this function, the model takes about 20 seconds to compile but takes 130 ms to run with the same input.
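One way to separate the two costs being compared here is JAX's ahead-of-time API: `jax.jit(fn).lower(...).compile()` performs tracing and XLA compilation up front, and the returned executable can then be timed on its own. The `model` function below is a hypothetical stand-in for a deep stack of layers:

```python
import time
import jax
import jax.numpy as jnp

def model(x):
    # Hypothetical stand-in: a long chain of small elementwise ops.
    for _ in range(200):
        x = jnp.sin(x) * 1.01 + 0.1
    return x

x = jnp.ones((256, 256))

t0 = time.perf_counter()
compiled = jax.jit(model).lower(x).compile()  # trace + XLA compile only
compile_s = time.perf_counter() - t0

compiled(x).block_until_ready()  # warm-up call
t0 = time.perf_counter()
y = compiled(x).block_until_ready()  # steady-state run time
run_s = time.perf_counter() - t0
print(f"compile: {compile_s:.3f} s, run: {run_s * 1e3:.2f} ms")
```

Because JAX dispatches work asynchronously, timings taken without `block_until_ready()` mostly measure dispatch rather than execution.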
I've looked through the control flow options in `jax.lax`, but I don't see how I can get this to work with `fori_loop` or `scan`. I'd like to JIT this function because it's 3x faster, but taking 4 minutes to compile is pretty rough. Any advice on how to write the above function so I can JIT it without a super long compile time? Thanks!