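import jax
f_wrapped = jax.jit(f)
f_wrapped.lower(*example_args).compile()  # trace and compile ahead of time for the shapes of example_args
But the compilation cache cannot be persisted across runs yet (this experimental feature is currently only available on TPU).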
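Applied to the question's range of shapes, the same ahead-of-time pattern can be run once per length before the main workload starts. The sketch below is a minimal illustration under stated assumptions: `f` is a hypothetical stand-in for the asker's padding function, `MAX_LEN` and `run` are illustrative names, and inputs are assumed to be 1-D `float32`. `jax.ShapeDtypeStruct` describes each argument without allocating real data. Each compiled executable only accepts arguments with the exact shape and dtype it was compiled for, so `run` dispatches on the input length.

```python
import jax
import jax.numpy as jnp

MAX_LEN = 200  # upper bound from the question: shapes (1,) through (200,)

def f(x):
    # Hypothetical stand-in for the asker's function: pad a 1-D array
    # to MAX_LEN and return it together with a validity mask.
    n = x.shape[0]  # static under jit, so the pad width is a Python int
    padded = jnp.pad(x, (0, MAX_LEN - n))
    mask = jnp.arange(MAX_LEN) < n
    return padded, mask

f_wrapped = jax.jit(f)

# Ahead-of-time: lower and compile one executable per input length.
compiled = {
    n: f_wrapped.lower(jax.ShapeDtypeStruct((n,), jnp.float32)).compile()
    for n in range(1, MAX_LEN + 1)
}

def run(x):
    # Dispatch to the precompiled executable that matches x's length.
    return compiled[x.shape[0]](x)

x = jnp.ones((37,), dtype=jnp.float32)
padded, mask = run(x)
```

Since, per the reply above, the compiled results are not persisted across runs, this loop would have to be repeated at every program start.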
Hi,
I'm working with a function whose input can have any shape within a known range (e.g. I know the shape is somewhere between (1,) and (200,), inclusive). The function pads its input to a static shape (i.e. essentially a mask) so that downstream tasks always receive the same shape. However, calling this function has become very expensive (it is called many times), so I'd like to JIT-compile it.
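A related way to avoid per-shape compilation altogether is to do the padding eagerly on the host, so that only fixed-shape arrays ever reach jitted code. This is a minimal sketch, not the asker's code: `pad_on_host`, `downstream`, and `MAX_LEN` are hypothetical names, `float32` inputs are assumed, and the masked mean merely stands in for whatever the real downstream task does.

```python
import numpy as np
import jax
import jax.numpy as jnp

MAX_LEN = 200  # assumed upper bound on the input length

def pad_on_host(x, max_len=MAX_LEN):
    # Plain NumPy runs eagerly on the host, so the varying input
    # lengths never reach the JAX compiler.
    x = np.asarray(x, dtype=np.float32)
    n = x.shape[0]
    if n > max_len:
        raise ValueError(f"length {n} exceeds max_len={max_len}")
    padded = np.zeros((max_len,), dtype=np.float32)
    padded[:n] = x
    mask = np.arange(max_len) < n
    return padded, mask

@jax.jit
def downstream(padded, mask):
    # Hypothetical downstream task: mean over the valid entries only.
    # It always sees shape (MAX_LEN,), so it is compiled once.
    total = jnp.sum(jnp.where(mask, padded, 0.0))
    return total / jnp.maximum(jnp.sum(mask), 1)

results = [downstream(*pad_on_host(np.arange(n))) for n in (1, 50, 200)]
```

The trade-off is that the padding itself is no longer compiled; whether that matters depends on how expensive the host-side copy is compared with the rest of the pipeline.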
I'm wondering whether JAX supports pre-compiling a function for this entire range of shapes, so that I can avoid expensive JIT compilations during execution.
Alternatively, is there a way to quickly compile many versions of a function in parallel? I only seem to use at most one CPU core during JIT compilation, but it would be nice to burst-compile the entire range before my program's main run.
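One approach worth trying for the parallel part is to separate the two AOT stages: trace and lower every shape serially (this runs Python code that holds the GIL), then submit the backend `compile()` calls to a thread pool. This is a sketch under assumptions, not a confirmed speedup: it relies on jaxlib releasing the GIL while XLA compiles, and on the backend tolerating concurrent compilations, so measure it on your setup. `f`, `lower_all`, `compile_all`, `MAX_LEN`, and `max_workers=8` are all hypothetical.

```python
import concurrent.futures

import jax
import jax.numpy as jnp

MAX_LEN = 200  # upper bound from the question

def f(x):
    # Hypothetical stand-in for the asker's padding function.
    n = x.shape[0]
    return jnp.pad(x, (0, MAX_LEN - n)), jnp.arange(MAX_LEN) < n

f_wrapped = jax.jit(f)

def lower_all(max_len):
    # Tracing and lowering execute Python code, so do them serially.
    return {
        n: f_wrapped.lower(jax.ShapeDtypeStruct((n,), jnp.float32))
        for n in range(1, max_len + 1)
    }

def compile_all(lowered, max_workers=8):
    # Backend compilation is the expensive step; submit each one to a
    # thread pool in the hope that XLA compiles them concurrently.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {n: pool.submit(low.compile) for n, low in lowered.items()}
        return {n: fut.result() for n, fut in futures.items()}

compiled = compile_all(lower_all(MAX_LEN))
```

Comparing wall-clock time against a plain serial loop over `lowered.values()` is the simplest way to check whether the thread pool helps on a given backend.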