Hi, I'm having an issue with a function that relies on dynamic slicing. The function was originally:

    def select_x_where_y0_equals_z0_and_y1_equals_z1(x, y0, z0, y1, z1, num_elems):
        return (x[:num_elems] * in_range(y0[:num_elems] - z0, -1.0, 1.0) * in_range(y1[:num_elems] - z1, -0.1, 0.1)).sum()

To run it (note that this call works because these values aren't Traced ShapedArray with DynamicJaxprTrace):

    select_x_where_y0_equals_z0_and_y1_equals_z1(jnp.array([13.888, 13.888, 13.888, 13.888]), jnp.array([1., 1., 2., 3.]), 1.0, jnp.array([40., 300., 340., 380.]), 300.0, 4)

However, this fails once the values are traced, because num_elems sizes the arrays dynamically. So I followed the advice in the error message and tried lax.dynamic_slice, which caused a new error.

    import functools

    import jax
    import jax.numpy as jnp
    from jax import jit, lax, value_and_grad  # public import instead of jax._src.api
    from jax.scipy.special import logsumexp

    @functools.partial(jax.jit, static_argnums=(5,))
    def select_x_where_y0_equals_z0_and_y1_equals_z1(x, y0, z0, y1, z1, num_elems):
        return jnp.sum(lax.dynamic_slice(x, (0,), (num_elems,)) *
                       in_range(lax.dynamic_slice(y0, (0,), (num_elems,)) - z0, -1., 1.) *
                       in_range(lax.dynamic_slice(y1, (0,), (num_elems,)) - z1, -.1, .1))

Here are the function dependencies:

    @jit
    def sigmoid2(x, x_offset):
        sigmoid_slope = 32.
        return 1. / (1. + jnp.exp(-sigmoid_slope * (x - x_offset)))

    @jit
    def in_range_unscaled(x, from_, to):
        return sigmoid2(x, from_) * (1. - sigmoid2(x, to))

    @jit
    def in_range(x, from_, to):
        # second term just for scaling
        return in_range_unscaled(x, from_, to) / in_range_unscaled((from_ + to) / 2.0, from_, to)

Going off the discussion in #1007, I believe this approach should work, yet I'm getting
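For anyone following along, here is a minimal sketch of the slicing problem described in the question. The helper head_sum and its inputs are made up for illustration and are not part of the original code. It assumes a standard JAX install; the exact exception raised for the traced case differs between JAX versions, so the sketch catches it broadly.

```python
import functools

import jax
import jax.numpy as jnp


def head_sum(x, num_elems):
    # Hypothetical helper (not from the post): sum of the first num_elems entries.
    return x[:num_elems].sum()


x = jnp.array([1.0, 2.0, 3.0, 4.0])

# Under plain jit, num_elems is traced, so the slice would need a
# data-dependent shape, which jit cannot compile.
traced_version = jax.jit(head_sum)
try:
    traced_version(x, 3)
    traced_failed = False
except Exception as err:  # exception type varies between JAX versions
    traced_failed = True
    print("traced slice failed with", type(err).__name__)

# Marking num_elems as static makes it an ordinary Python int at trace time,
# so the slice has a known shape. The cost is one compilation per distinct value.
static_version = functools.partial(jax.jit, static_argnums=(1,))(head_sum)
static_result = static_version(x, 3)
print(static_result)
```

The static route is only practical when num_elems takes a small number of distinct values, since each new value triggers a recompile.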
Replies: 1 comment
I was able to solve this problem using the second suggestion by @mattjj, so thanks. Here is the solution:

    @jax.jit
    def select_x_where_y0_equals_z0_and_y1_equals_z1(x, y0, z0, y1, z1, num_elems):
        return jnp.sum(jnp.where(jnp.arange(len(x)) < num_elems, x, 0) *
                       in_range(jnp.where(jnp.arange(len(y0)) < num_elems, y0, 0) - z0, -1., 1.) *
                       in_range(jnp.where(jnp.arange(len(y1)) < num_elems, y1, 0) - z1, -.1, .1))

Now the question is: is this a differentiable solution?
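One way to probe the differentiability question is to run jax.grad on the masked version. The sketch below repeats the helpers from the question (without the inner @jit decorators, which should not change the result) and differentiates with respect to x only. The shorter name masked_select and the choice of argnums=0 are mine, not the original poster's. Since num_elems is an integer that only feeds a comparison, no gradient with respect to it is expected.

```python
import jax
import jax.numpy as jnp


def sigmoid2(x, x_offset):
    sigmoid_slope = 32.0
    return 1.0 / (1.0 + jnp.exp(-sigmoid_slope * (x - x_offset)))


def in_range_unscaled(x, from_, to):
    return sigmoid2(x, from_) * (1.0 - sigmoid2(x, to))


def in_range(x, from_, to):
    # Second term only rescales so the midpoint of the range maps to 1.
    return in_range_unscaled(x, from_, to) / in_range_unscaled((from_ + to) / 2.0, from_, to)


@jax.jit
def masked_select(x, y0, z0, y1, z1, num_elems):
    # Same masking as the solution in the post, under a shorter (hypothetical) name.
    return jnp.sum(jnp.where(jnp.arange(len(x)) < num_elems, x, 0) *
                   in_range(jnp.where(jnp.arange(len(y0)) < num_elems, y0, 0) - z0, -1.0, 1.0) *
                   in_range(jnp.where(jnp.arange(len(y1)) < num_elems, y1, 0) - z1, -0.1, 0.1))


x = jnp.array([13.888, 13.888, 13.888, 13.888])
y0 = jnp.array([1.0, 1.0, 2.0, 3.0])
y1 = jnp.array([40.0, 300.0, 340.0, 380.0])

# Gradient with respect to x (argument 0) for two different cut-offs.
grad_fn = jax.grad(masked_select, argnums=0)
grad_two = grad_fn(x, y0, 1.0, y1, 300.0, 2)  # entries 0 and 1 are active
grad_one = grad_fn(x, y0, 1.0, y1, 300.0, 1)  # only entry 0 is active
print(grad_two)
print(grad_one)
```

With these inputs only index 1 matches both conditions, so the gradient with respect to x is nonzero there when num_elems covers it and zero everywhere once it is masked out. This check covers only x; gradients with respect to y0 or y1 pass through the steep sigmoids and are worth checking separately for your own value ranges.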