Hi - thanks for the question! The issue is that static operations like dict key lookups happen at trace-time, and lax.cond will always trace both branches (though at runtime it will execute only one branch). You can fix this by changing fb to account for the fact that it will be traced regardless of the condition; e.g.

  def fb(state):
    return state.get('A', 0)
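Here is a minimal sketch of the trace-time behavior described above. The branch names `fa`, `fb_broken` and `fb_fixed`, the wrapper `run`, and the dict contents are hypothetical. It assumes the state only ever contains the key `'B'`. The predicate is passed through `jax.jit`, so it is a traced value, and `lax.cond` must trace both branches:

```python
import jax
from jax import lax

# Hypothetical branches: the state only has key 'B'.
def fa(state):
    return state['B']

def fb_broken(state):
    # Raises KeyError while tracing, even when pred selects fa at runtime.
    return state['A']

def fb_fixed(state):
    # Safe to trace when 'A' is absent: falls back to 0.
    return state.get('A', 0)

@jax.jit
def run(pred, state):
    # Both branches are traced here; only one executes at runtime.
    return lax.cond(pred, fa, fb_fixed, state)

state = {'B': 5}
print(int(run(True, state)))   # 5
print(int(run(False, state)))  # 0

try:
    jax.jit(lambda p, s: lax.cond(p, fa, fb_broken, s))(True, state)
except KeyError as e:
    print("KeyError raised during tracing:", e)
```

Both branches return integer scalars here, so their output types agree as `lax.cond` requires.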

But backing up a step, given that your condition is a statically known quantity, there's no need to use lax.cond and push the logic into the compiler. You could simplify your function by evaluating this static condition at trace time at the top level; for example:

@jax.jit
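Here is one way this could look. This is a sketch, and it assumes the condition is simply whether the key `'A'` is present in `state`, which is a hypothetical stand-in for the original condition. Dict keys are part of the pytree structure, and `jax.jit` treats that structure as static. A plain Python `if` on them is therefore resolved at trace time, and only the selected branch is ever traced:

```python
import jax

@jax.jit
def f(state):
    # 'A' in state depends only on the dict's keys (static pytree structure),
    # so this `if` runs during tracing; no lax.cond is needed.
    if 'A' in state:
        return state['A']
    return state['B']

print(int(f({'B': 5})))          # 5
print(int(f({'A': 1, 'B': 5})))  # 1
```

Each distinct set of keys triggers a separate trace and compilation. If the static condition is instead a Python boolean argument, marking it with `static_argnames` in `jax.jit` achieves the same effect.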

Answer selected by eadadi
Category
Q&A
2 participants