
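By default, `jax.lax.scan` traces the body function once and compiles it, so a Python side effect such as `print` runs only at trace time and shows an abstract tracer instead of concrete values. To disable this behavior, run the call inside a `jax.disable_jit()` context, which makes `scan` call the body function as an ordinary Python loop. For example: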

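```python
import jax
import jax.numpy as jnp

def f(x):
  def body_fun(carry, x):
    print(x)  # Python side effect: runs only when body_fun is actually called
    return carry, x
  return jax.lax.scan(body_fun, 1, x)

f(jnp.arange(4))
# Output (body traced once; the exact tracer repr varies by JAX version):
# Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>

with jax.disable_jit():
  f(jnp.arange(4))
# Output (body called once per element):
# 0
# 1
# 2
# 3
```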
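To see why `disable_jit` yields one concrete print per element, it can help to compare with the plain-Python loop that `scan` is semantically equivalent to. The sketch below is a simplified model, not JAX's implementation: `scan_reference` is a hypothetical helper name, and it assumes a single array `xs` while ignoring pytrees and the `length`, `reverse`, and `unroll` options.

```python
import numpy as np

def scan_reference(f, init, xs):
    """Simplified loop semantics of jax.lax.scan (hypothetical helper).

    Assumes xs is a single array scanned along its leading axis;
    pytrees and the length/reverse/unroll options are not handled.
    """
    carry = init
    ys = []
    for x in xs:                # one Python-level call per leading-axis slice
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)  # stack per-step outputs along a new leading axis

def body_fun(carry, x):
    print(x)                    # runs on every iteration with a concrete value
    return carry, x

carry, ys = scan_reference(body_fun, 1, np.arange(4))
# prints 0, 1, 2, 3 on separate lines
```

Because each iteration is a real Python call, `print` behaves as it would in eager code, which matches the output shown under `jax.disable_jit()`.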

Replies: 2 comments 2 replies

@YouJiacheng
@YouJiacheng
Answer selected by jacob-jacob-jacob