Hi, I had to do a basic computation in JAX. Specifically, I had to handle the following computation:

```python
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import jit, lax


@partial(jit, static_argnames=["n_steps"])
def gd_filter_jax(sigma: jnp.ndarray, n_steps: int, power: float):
    result = jnp.zeros_like(sigma)
    for k in range(1, n_steps + 1):
        prod_k = jnp.ones_like(sigma)
        for i in range(k + 2, n_steps + 1):
            tau_i = i ** (-power)
            prod_k *= 1 - tau_i * sigma
        result += prod_k * k ** (-power)
    return result
```
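Written out, this computes, elementwise over `sigma`,

$$
\mathrm{result} \;=\; \sum_{k=1}^{n_\text{steps}} k^{-p} \prod_{i=k+2}^{n_\text{steps}} \bigl(1 - i^{-p}\,\sigma\bigr),
$$

where $p$ is `power` and the inner product is empty (equal to 1) for $k \ge n_\text{steps} - 1$.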
I compared it with a native Python loop (without jit) and a version with `lax`:

```python
@partial(jit, static_argnames=["n_steps"])
def gd_filter_lax(sigma: jnp.ndarray, n_steps: int, power: float):
    result = lax.fori_loop(
        1, n_steps + 1,
        lambda k, s: s + k ** (-power) * lax.fori_loop(
            k + 2, n_steps + 1,
            lambda i, p: p * (1 - i ** (-power) * sigma),
            jnp.ones_like(sigma),
        ),
        jnp.zeros_like(sigma),
    )
    return result
```
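For reference, the native Python loop baseline is just the same body without the `jit` decorator, roughly (shown here as `gd_filter`, matching the timing call below):

```python
def gd_filter(sigma: jnp.ndarray, n_steps: int, power: float):
    # Same double loop as gd_filter_jax, but with no jit: every jnp op dispatches eagerly.
    result = jnp.zeros_like(sigma)
    for k in range(1, n_steps + 1):
        prod_k = jnp.ones_like(sigma)
        for i in range(k + 2, n_steps + 1):
            prod_k *= 1 - i ** (-power) * sigma
        result += prod_k * k ** (-power)
    return result
```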
I called these functions with the following parameters:

```python
sigma = np.arange(1, 100) ** (-2.)
%timeit gd_filter(sigma, 100, 1.).block_until_ready()
```

and likewise for the other variants. The `lax.fori_loop` version turned out slower than the unrolled jit version: why is that?
I think the short answer is that `lax.fori_loop` is designed for truly sequential cases, where each individual iterate depends on the previous step and the step function is reasonably complex (think a step of gradient descent, or an action in a reinforcement learning context). If jit were to statically unroll e.g. 1000 steps of gradient descent, it would take a huge amount of time for XLA to actually compile the function. In the `fori_loop` context, XLA is encouraged (possibly required? Disclaimer: I'm not an XLA expert) to treat each iterate as its own individual black box, so it can't make any simplifications/speedups by e.g. batching individual multiplies into a faster matrix multiply.

In your…

As an aside, here's how you could do the filter without any loops:
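One way to do that (a sketch, not necessarily the original snippet; the function name `gd_filter_noloop` is only illustrative) is to broadcast the per-step factors into a 2-D array and take a reversed cumulative product over the step axis:

```python
@partial(jit, static_argnames=["n_steps"])
def gd_filter_noloop(sigma: jnp.ndarray, n_steps: int, power: float):
    # Step indices 1..n_steps as a column, so everything broadcasts against sigma.
    idx = jnp.arange(1, n_steps + 1, dtype=sigma.dtype)[:, None]
    # factors[i-1] = 1 - i**(-power) * sigma, shape (n_steps, len(sigma)).
    factors = 1.0 - idx ** (-power) * sigma[None, :]
    # Pad with two rows of ones so the empty products for k = n_steps - 1
    # and k = n_steps come out as 1.
    padded = jnp.concatenate([factors, jnp.ones((2, sigma.shape[0]), sigma.dtype)], axis=0)
    # suffix[j] = prod_{i=j+1}^{n_steps} (1 - i**(-power) * sigma).
    suffix = jnp.flip(jnp.cumprod(jnp.flip(padded, axis=0), axis=0), axis=0)
    # prod_{i=k+2}^{n_steps} (...) is suffix[k + 1], for k = 1..n_steps.
    prod_k = suffix[2:]
    # result = sum_k k**(-power) * prod_k, reduced over the step axis.
    return jnp.sum(idx ** (-power) * prod_k, axis=0)
```

This materializes an `(n_steps, len(sigma))` array, trading the Python-level loops for a bit of memory, and hands XLA one big elementwise/cumulative computation to optimize; it should agree with the unrolled version up to floating-point round-off (e.g. `jnp.allclose(gd_filter_noloop(sigma, 100, 1.0), gd_filter_jax(sigma, 100, 1.0))`).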