Hello, I want to understand why the following code raises `KeyError: 'A'`:

```python
import jax
from jax import lax

@jax.jit
def r(state):
  def fa(state):
    return 0
  def fb(state):
    return state['A']  # the KeyError is raised here
  return lax.cond('A' not in state, fa, fb, state)

state = {}
state['A'] = r(state)
```

The compiled version should work for any version of the dictionary, shouldn't it? The reason I need something like this to work is that I want to be able to do something like:

```python
@jax.jit
def r(state):
  def fa(state):
    return 0
  def fb(state):
    return state['A']
  return lax.cond('A' not in state, fa, fb, state)

state = {}
state['A'] = r(state)
r(state)  # now the condition takes the other branch, and it should be valid because 'A' is in state
```
Replies: 1 comment 5 replies
Hi - thanks for the question! The issue is that static operations like dict key lookups happen at trace time, and `lax.cond` will always trace both branches (though at runtime it will execute only one branch). You can fix this by changing `fb` to account for the fact that it will be traced regardless of the condition; e.g.

```python
def fb(state):
  return state.get('A', 0)
```

But backing up a step, given that your condition is a statically known quantity, there's no reason to use `lax.cond` and to push the logic into the compiler. You could simplify your function by evaluating this static condition at trace time at the top level; for example:

```python
@jax.jit
def r(state):
  return state.get('A', 0)
```
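To make the two points in this reply concrete, here is a minimal sketch that is not part of the original thread. It assumes a recent JAX release. The names `traced`, `true_branch`, `false_branch`, and `trace_log` are hypothetical helpers introduced only for illustration. The first half records which branch functions Python runs while `lax.cond` is being traced. The second half resolves the dict-key check with an ordinary Python `if` at trace time and logs each trace, on the assumption that `jax.jit` traces again when the dict gains a new key.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Part 1: lax.cond traces both branches, whatever the predicate's value is.
traced = []  # hypothetical log of branch functions executed during tracing

def true_branch(x):
    traced.append("true_branch")  # Python side effect: runs only while tracing
    return x + 1.0

def false_branch(x):
    traced.append("false_branch")
    return x - 1.0

@jax.jit
def f(x):
    return lax.cond(x > 0, true_branch, false_branch, x)

out_pos = f(jnp.float32(2.0))   # predicate is True at runtime
out_neg = f(jnp.float32(-2.0))  # predicate is False at runtime
print(sorted(set(traced)))      # both branch names were recorded

# Part 2: a static condition handled in plain Python at trace time.
trace_log = []  # hypothetical log of the dict keys seen by each trace

@jax.jit
def r(state):
    trace_log.append(tuple(state.keys()))
    if 'A' in state:          # ordinary Python check, resolved while tracing
        return state['A']
    return 0

state = {}
state['A'] = r(state)  # traced with an empty dict
second = r(state)      # the dict now has key 'A', so r is traced again
print(trace_log)
```

Because the key check is ordinary Python, only the branch that matches the current dict structure ends up in the compiled function, so `state['A']` is never evaluated when the key is missing.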