Print Value in a non-Jitted Function? #9913
-
Hi, I want to debug a function. I am passing in JAX objects and using … Still, when I print an array, I only get something like …
-
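This is because … One workaround is a drop-in replacement for `lax.scan` that, when `DEBUG = True`, runs the loop as an ordinary Python loop so that `print` inside the body sees concrete values:

```python
import jax
import jax.numpy as jnp
from jax import lax

DEBUG = True

def scan(f, init, xs, length=None, **kwargs):
    if not DEBUG:
        return lax.scan(f, init, xs, length, **kwargs)
    # Debug path: a plain Python loop, so f receives concrete arrays, not tracers.
    # Extra kwargs such as reverse are ignored on this path.
    carry = init
    xs_flat = jax.tree_util.tree_leaves(xs)
    lengths = [x.shape[0] for x in xs_flat]
    if length is None:
        length = lengths[0]  # xs=None together with length=None is not supported
    assert all(length == l for l in lengths)

    def xs_generator():
        for i in range(length):
            yield jax.tree_util.tree_map(lambda x: x[i], xs)

    ys = []
    for x in xs_generator():
        carry, y = f(carry, x)
        ys.append(y)
    # Stack the per-step outputs along a new leading axis, as lax.scan does.
    return carry, jax.tree_util.tree_map(lambda *x: jnp.stack(x), *ys)
```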
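A different approach, not taken from the reply above, is `jax.debug.print`, which newer JAX releases provide for printing runtime values from inside traced code, so `lax.scan` (and even `jax.jit`) can stay in place. This is a minimal sketch; the `running_sum` function and its body are hypothetical, chosen only for illustration:

```python
import jax
import jax.numpy as jnp

def running_sum(xs):
    def body_fun(carry, x):
        # Unlike print(), jax.debug.print is staged into the traced computation,
        # so it reports the runtime value of x at each scan step.
        jax.debug.print("x = {}", x)
        carry = carry + x
        return carry, carry  # also emit the running sum as the per-step output

    init = jnp.zeros((), dtype=xs.dtype)  # match the dtype of xs for the carry
    return jax.lax.scan(body_fun, init, xs)

total, partial_sums = jax.jit(running_sum)(jnp.arange(4))
jax.effects_barrier()  # wait until pending debug prints have been flushed
```

Because the print is part of the computation, this keeps the speed of the compiled loop, whereas the Python-loop replacement above runs each step op by op.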
-
`scan` traces (and compiles) the body function by default, so a plain `print` inside it sees abstract tracers rather than concrete values. If you want to disable this behavior, you can run the function inside a `jax.disable_jit()` context; for example:

```python
import jax
import jax.numpy as jnp

def f(x):
    def body_fun(carry, x):
        print(x)
        return carry, x
    return jax.lax.scan(body_fun, 1, x)

f(jnp.arange(4))
# Output (the exact tracer text varies between JAX versions):
# Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>

with jax.disable_jit():
    f(jnp.arange(4))
# Output:
# 0
# 1
# 2
# 3
```
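`jax.disable_jit()` is not specific to `scan`: inside the context, functions decorated with `jax.jit` also run eagerly, so a plain `print` shows concrete arrays. A small sketch with a hypothetical `double` function:

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    # Under jit, this print runs only while tracing and shows a tracer.
    # Under jax.disable_jit(), it runs on every call with the concrete array.
    print("inside double:", x)
    return x * 2

traced_out = double(jnp.arange(3))  # prints a tracer, only on the first (tracing) call

with jax.disable_jit():
    eager_out = double(jnp.arange(3))  # prints the concrete array [0 1 2]
```

Since everything inside the context executes op by op, it is usually much slower, so it is best kept to debugging sessions.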