
Indeed, when differentiating the jnp.where, the primal values of the boolean array abs(evals) > cutoff are used to select both the primal and the tangent values. It's the same idea as differentiating lambda x: x ** 2 if x > 0 else x at the primal value x=1.0 with the tangent value x_dot=-1.0. We're linearizing around the primal point, so we want to choose the branch based on the primal value only and then have the tangent follow along (i.e. go through the x ** 2 branch) rather than take its own path. Here we're using a jnp.where instead of an if, but the logic is the same (like differentiating lambda x: jnp.where(x > 0, x ** 2, x)).
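A minimal sketch of this behaviour using jax.jvp, assuming a recent JAX install. The function names (f_if, f_where, masked) and the evals, cutoff and tangent values in the second half are hypothetical, chosen only to show that the branch or mask is computed from the primal values while the tangent just follows it:

```python
import jax
import jax.numpy as jnp

# Python `if` version: the branch is chosen by the primal value of x.
f_if = lambda x: x ** 2 if x > 0 else x
# jnp.where version: the same logic, expressed as an elementwise select.
f_where = lambda x: jnp.where(x > 0, x ** 2, x)

# Linearize at primal x = 1.0 with tangent x_dot = -1.0.
# The tangent does not pick its own branch; it goes through x ** 2,
# so the output tangent is 2 * x * x_dot = 2 * 1.0 * (-1.0) = -2.0.
print(jax.jvp(f_if, (1.0,), (-1.0,)))
print(jax.jvp(f_where, (1.0,), (-1.0,)))

# The same effect with an array mask like abs(evals) > cutoff.
# evals, cutoff and evals_dot are hypothetical illustration values.
cutoff = 1.0

def masked(evals):
    # Keep entries whose primal magnitude exceeds the cutoff, zero the rest.
    return jnp.where(jnp.abs(evals) > cutoff, evals, 0.0)

evals = jnp.array([0.5, 2.0])
# A tangent that would push evals[0] past the cutoff if the mask used it.
evals_dot = jnp.array([10.0, 1.0])
out, out_dot = jax.jvp(masked, (evals,), (evals_dot,))
# The mask comes from the primal evals only, so the tangent of the
# first entry is zeroed even though its tangent is large.
print(out, out_dot)
```

Both scalar versions return a primal of 1.0 and a tangent of -2.0, and in the array case out_dot is zero wherever the primal mask is False.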

So, super concretely, when we differentiate l…

@EricaCMitchell
@mattjj

Answer selected by EricaCMitchell
@EricaCMitchell
Category
Q&A
2 participants