Jitting nested functions #15578
-
Thanks for the question! I'm guessing your code is spending all its time compiling because there are large unrolled loops (i.e. a Python loop construct, which gets unrolled when JAX traces and compiles it). I can't tell about … If that's the issue, and you still want … You can also set …
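To illustrate the unrolling described above, here is a minimal sketch. The update rule `step` and the iteration count are hypothetical, not taken from the question. A Python `for` loop inside a traced function runs at trace time, so every iteration adds its operations to the program XLA has to compile, while `jax.lax.fori_loop` traces the body once and emits a single loop primitive. `jax.make_jaxpr` makes the difference visible as an equation count.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # Hypothetical per-iteration update standing in for one solver step.
    return 0.5 * x + 1.0


def loop_python(x, n=1000):
    # Python loop: executed while tracing, so `step` is copied n times
    # into the traced program (this is the "unrolled" case).
    for _ in range(n):
        x = step(x)
    return x


def loop_fori(x, n=1000):
    # lax.fori_loop: `step` is traced once and wrapped in one loop primitive.
    return lax.fori_loop(0, n, lambda i, x: step(x), x)


x0 = jnp.float32(0.0)

# Number of top-level operations in each traced program.
n_python = len(jax.make_jaxpr(loop_python)(x0).jaxpr.eqns)
n_fori = len(jax.make_jaxpr(loop_fori)(x0).jaxpr.eqns)

# Both versions compute the same thing once compiled.
r_python = jax.jit(loop_python)(x0)
r_fori = jax.jit(loop_fori)(x0)
print(n_python, n_fori, r_python, r_fori)
```

The unrolled program grows linearly with the number of iterations, which is what makes compilation slow for long loops; the `fori_loop` version stays the same size no matter how many iterations you ask for.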
-
I am completely new here, but I am guessing that having the mainLoop unjitted means that … Am I wrong here? PS: Thanks for the quick reply :)
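The reasoning in this reply is cut off, so the sketch below does not reconstruct it. It only shows what happens in general when the loop driver is left unjitted and only the per-iteration function is jitted. The names `step`, `main_loop` and the `trace_count` counter are hypothetical. The counter is a Python side effect, which runs only while JAX traces the function, so it reveals how many times `step` was actually traced and compiled.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only during tracing


@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on every call
    return 0.5 * x + 1.0


def main_loop(x, n):
    # Not jitted: this Python loop runs eagerly and calls the compiled
    # `step` n times, reusing the cached executable because the input
    # shape and dtype do not change between iterations.
    for _ in range(n):
        x = step(x)
    return x


result = main_loop(jnp.float32(0.0), 100)
print(trace_count, result)
```

With this structure nothing is unrolled, because the loop never gets traced; the trade-off is one Python-level dispatch per iteration instead of a single compiled loop.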
-
I am trying to run a simulation with JAX for my Final Degree Project. This simulation uses two large functions to compute the two sides of an equation, and these functions are interchangeable. For instance, I might have a computeLeft() function and a computeLeftDimensionless() function, and I want to be able to switch between them at will. Finally, a mainLoop() function solves the equation for a given number of iterations.
So, once I've selected which functions to use for the left and right sides of the equation, I want to jit them and pass them to the mainLoop function, which should also be jitted. I think this is how it should be done: … where mainLoop is something like the following: …
I am pretty sure I might be missing something here, but my program just gets stuck after tracing the jitMainLoop function and never seems to do anything.
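One common way to structure what the question describes is sketched below. The question's actual code is not preserved, so `compute_left`, `compute_left_dimensionless`, `compute_right`, the update rule and the iteration count are all hypothetical stand-ins. Python functions are not valid array arguments to a jitted function, so they are marked static with `static_argnames`, and the iteration loop uses `lax.fori_loop` so its length does not affect compile time. The inner functions do not need their own `jax.jit`: everything called inside `main_loop` is traced into the same compiled program.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical stand-ins for computeLeft / computeLeftDimensionless / a right side.
def compute_left(x):
    return 0.5 * x


def compute_left_dimensionless(x):
    return 0.25 * x


def compute_right(x):
    return 1.0 + 0.0 * x  # constant right side, same shape as x


# left_fn and right_fn are Python callables, so they must be static:
# jit hashes them and compiles one version per (left_fn, right_fn) pair.
@partial(jax.jit, static_argnames=("left_fn", "right_fn"))
def main_loop(left_fn, right_fn, x, n_iter):
    def body(i, x):
        # One iteration of a hypothetical fixed-point update x <- left(x) + right(x).
        return left_fn(x) + right_fn(x)

    # n_iter is a traced value, so this compiles to a single while loop
    # regardless of how many iterations are requested (nothing is unrolled).
    return lax.fori_loop(0, n_iter, body, x)


x0 = jnp.zeros(3)
a = main_loop(compute_left, compute_right, x0, 200)
b = main_loop(compute_left_dimensionless, compute_right, x0, 200)
print(a, b)
```

Calling `main_loop` with a different `left_fn` or `right_fn` triggers one new compilation for that combination; repeated calls with the same pair (and same input shapes) reuse the cached executable.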