
I was able to solve this problem using the second suggestion by @mattjj (thanks!). Here is the solution:

import jax
import jax.numpy as jnp

@jax.jit
def select_x_where_y0_equals_z0_and_y1_equals_z1(x, y0, z0, y1, z1, num_elems):
    # `in_range` is not defined in this post; presumably it comes from the earlier suggestion.
    return jnp.sum(jnp.where(jnp.arange(len(x)) < num_elems, x, 0) *
        in_range(jnp.where(jnp.arange(len(y0)) < num_elems, y0, 0) - z0, -1., 1.) *
        in_range(jnp.where(jnp.arange(len(y1)) < num_elems, y1, 0) - z1, -.1, .1))
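The post does not show `in_range`, which presumably comes from the earlier suggestion in the thread. Below is a self-contained sketch that assumes `in_range` is a hard indicator returning 1.0 where `lo <= v <= hi` and 0.0 otherwise (an assumed definition, not necessarily the one from the thread). It builds the length mask once, which assumes `x`, `y0`, and `y1` share the same length, and runs the function on a small made-up input.

```python
import jax
import jax.numpy as jnp


def in_range(v, lo, hi):
    # Assumed definition (not shown in the post): 1.0 where lo <= v <= hi, else 0.0.
    return jnp.where((v >= lo) & (v <= hi), 1.0, 0.0)


@jax.jit
def select_x_where_y0_equals_z0_and_y1_equals_z1(x, y0, z0, y1, z1, num_elems):
    # num_elems may be a traced value: it only enters a comparison, so every
    # array keeps a static shape and jit does not recompile for each new value.
    valid = jnp.arange(x.shape[0]) < num_elems
    x_m = jnp.where(valid, x, 0.0)
    y0_m = jnp.where(valid, y0, 0.0)
    y1_m = jnp.where(valid, y1, 0.0)
    # Sum x over entries whose y0 is within 1.0 of z0 and whose y1 is within 0.1 of z1.
    return jnp.sum(x_m * in_range(y0_m - z0, -1.0, 1.0) * in_range(y1_m - z1, -0.1, 0.1))


x = jnp.array([1.0, 2.0, 3.0, 4.0])
y0 = jnp.array([0.5, 5.0, 0.0, 0.0])
y1 = jnp.array([0.05, 0.0, 1.0, 0.0])
# Only the first 3 entries are valid, and only index 0 passes both tolerance checks.
total = select_x_where_y0_equals_z0_and_y1_equals_z1(x, y0, 0.0, y1, 0.0, 3)
print(total)  # 1.0
```

Because the masked-out entries of `x` are set to 0, they contribute nothing to the sum even if their `y` values happen to fall inside the window.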

Now the question is: is this solution differentiable?
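One way to probe this is to call `jax.grad` on the function directly. The sketch below assumes `in_range` is a hard 0/1 indicator (the post does not show its definition), drops `jax.jit` for brevity, and assumes `x`, `y0`, and `y1` have the same length. Under those assumptions the sum is linear in `x`, so the gradient with respect to `x` is just the selection weights. The output is piecewise constant in `y0`, `z0`, `y1`, and `z1`, so `jax.grad` returns zeros for them, and `num_elems` is an integer that cannot be differentiated at all. If gradients with respect to the `y`/`z` values are needed, one common workaround (swapped in here, not something proposed in the thread) is to replace the hard window with a smooth one, such as a product of two sigmoids with a temperature `tau`. The names `soft_in_range`, `select`, `soft_select`, and `tau` are hypothetical.

```python
import jax
import jax.numpy as jnp


def in_range(v, lo, hi):
    # Assumed hard indicator: 1.0 where lo <= v <= hi, else 0.0.
    # The comparisons carry no gradient, so d(in_range)/dv is 0 everywhere.
    return jnp.where((v >= lo) & (v <= hi), 1.0, 0.0)


def soft_in_range(v, lo, hi, tau=0.01):
    # Hypothetical smooth window: tends to the hard indicator as tau -> 0.
    return jax.nn.sigmoid((v - lo) / tau) * jax.nn.sigmoid((hi - v) / tau)


def select(x, y0, z0, y1, z1, num_elems, window=in_range):
    # Same masking as the post: entries at index >= num_elems are zeroed out.
    valid = jnp.arange(x.shape[0]) < num_elems
    x_m = jnp.where(valid, x, 0.0)
    y0_m = jnp.where(valid, y0, 0.0)
    y1_m = jnp.where(valid, y1, 0.0)
    return jnp.sum(x_m * window(y0_m - z0, -1.0, 1.0) * window(y1_m - z1, -0.1, 0.1))


def soft_select(*args):
    # Same computation, but with the smooth window instead of the hard one.
    return select(*args, window=soft_in_range)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
y0 = jnp.array([0.5, 5.0, 0.0, 0.0])
y1 = jnp.array([0.05, 0.0, 1.0, 0.0])
args = (x, y0, 0.0, y1, 0.0, 3)

# Differentiate w.r.t. x (argnum 0) and y1 (argnum 3); num_elems is never differentiated.
gx, gy1 = jax.grad(select, argnums=(0, 3))(*args)
gx_s, gy1_s = jax.grad(soft_select, argnums=(0, 3))(*args)
print(gx)     # selection weights: 1, 0, 0, 0
print(gy1)    # hard window: all zeros
print(gy1_s)  # soft window: no longer identically zero
```

A smaller `tau` approximates the hard window more closely, but the gradient then vanishes everywhere except in a narrow band around the window edges, so `tau` trades fidelity against gradient signal.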

Answer selected by mjhoover1
Category
Q&A