-
I am trying to finish up a function, but an issue arose carry = lane, leader_dist
lane, leader_dist = lax.cond(step != 0 and step % lane_change_period == 0, body, lambda x: carry, carry) This will change the lane variable and leader distance it's used for lane changes in a highway simulator, but it never reshapes the lane jnp array. Here is the body: def body(carry):
lane, leader_dist = carry
min_target_lane = lane[road][i] - 1. * sigmoid2(lane[road][i] - 1., .5)
max_target_lane = lane[road][i] + 1. * (1. - sigmoid2(lane[road][i] - 3 + 1., .5)) # 3=number of lanes
leader_dists = jnp.zeros(3) # 3=number of lanes
carry = leader_dists, min_target_lane, max_target_lane
leader_dists, _, _ = lax.fori_loop(0, 3, inner_loop_body5, carry) # 3=number of lanes
max_leader_dist = max_pos(leader_dists, 3) # 3=number of lanes
furthest_ahead_leader_lane = select_x_where_y_equals_z(lane[road], pos[road], pos[road, i] + max_leader_dist, num_agents_per_road)
new_lane = select_x_if_larger_y_else_z(furthest_ahead_leader_lane, .5, lane[road, i])
clearance_gain = max_leader_dist - leader_dist
new_lane_thresh_factor = sigmoid(clearance_gain - min_clearance_gain)
new_lane_thresh_factor *= sigmoid(vel[road, i] - min_lane_change_vel)
new_lane = new_lane * new_lane_thresh_factor + lane[road, i] * (1. - new_lane_thresh_factor)
lane = lane.at[road, i].set(new_lane)
return lane, leader_dist This function also relies on inner_loop_body5: def inner_loop_body5(j, carry):
leader_dists, min_target_lane, max_target_lane = carry
leader_dists = leader_dists.at[j].set(min_pos_delta_if_equal(pos[road], num_agents_per_road, i, lane[road], j + 1))
ir = in_range(j + 1, min_target_lane - .5, max_target_lane + .5)
leader_dists = leader_dists.at[j].multiply(ir)
return leader_dists, min_target_lane, max_target_lane I'm not quite sure what's causing this issue, but is it possible to get around this error? The above is an attempt to summarize the problem to create a complete minimal example. Here is a more verbose example if necessary to get the error to occur, but none of the aux functions are a problem the issues are from the lax.cond call above: import jax.numpy as jnp
from jax import jit, lax
from jax.config import config
from jax.scipy.special import logsumexp
import jax
from functools import partial
from jax.scipy.special import expit
config.update('jax_platform_name', 'cpu')
config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)
config.update('jax_disable_jit', False)
lane_change_period = 25
min_clearance_gain = 10.
min_lane_change_vel = 8.
traffic_sim_sigmoid_slope = 16.0
traffic_light_ahead_sigmoid_slope = 16.0
road_end_pos = 250.
light_pos=100.
road = 0
i = 0
pos = [[40., 300., 340., 380.]]
num_agents_per_road = 4
lane = [[40., 300., 340., 380.]]
@jit
def min_pos_delta_if_equal(x, num_elems, ref_idx, y, z):
M = 1000.
@jit
def body(carry):
i, sum = carry
diff = x[i] - x[ref_idx]
neg_offset = 2 * M * (1.0 - sigmoid(diff))
diff += neg_offset
lane_shift = 1000. * (1.0 - in_range(y[i] - z, -0.5, 0.5))
diff += lane_shift
exp_arg = -diff
sum += expit(exp_arg)
return sum
@jit
def outer_body(i, sum):
return lax.cond(i == ref_idx, lambda x: x[1], body, (i, sum))
sum = lax.fori_loop(0, num_elems, outer_body, 0.)
return -jnp.log(sum)
@jit
def sigmoid(x):
sigmoid_slope = 32.
return 1.0 / (1.0 + expit(-sigmoid_slope * x))
@jit
def sigmoid2(x, x_offset):
sigmoid_slope = 32.
# return my_div(1. + jnp.exp(-sigmoid_slope * (x - x_offset)))
return 1. / (1. + expit(-sigmoid_slope * (x - x_offset)))
@jit
def in_range_unscaled(x, from_, to):
return sigmoid2(x, from_)*(1. - sigmoid2(x, to))
@jit
def in_range(x, from_, to):
return in_range_unscaled(x, from_, to) / in_range_unscaled((from_ + to) / 2.0, from_, to) # second term just for scaling
@jit
def max_pos(x, num_elems):
return logsumexp(x[:num_elems])
@jit
def select_x_where_y_equals_z(x, y, z, num_elems):
return (x[:num_elems] * in_range(y[:num_elems] - z, -1., 1.)).sum() # range_scale is 1.
@jit
def select_x_if_larger_y_else_z(x, y, z):
return sigmoid2(x, y) * x + (1.0 - sigmoid2(x, y)) * z
@jit
def select_x_where_y0_equals_z0_and_y1_equals_z1(x, y0, z0, y1, z1, num_elems):
return jnp.sum(jnp.where(jnp.arange(len(x)) < num_elems, x, 0) *
in_range(jnp.where(jnp.arange(len(y0)) < num_elems, y0, 0) - z0, -1., 1.) *
in_range(jnp.where(jnp.arange(len(y1)) < num_elems, y1, 0) - z1, -.1, .1))
@jit
def sigmoid_with_slope(x, slope):
return 1.0 / (1.0 + expit(-slope * x))
@partial(jit, static_argnums=(0,))
def inner_loop_body4(step, i, lane, pos, vel, num_agents_per_road, sum_progress, curr_light_red, road):
def inner_loop_body5(j, carry):
leader_dists, min_target_lane, max_target_lane = carry
leader_dists = leader_dists.at[j].set(min_pos_delta_if_equal(pos[road], num_agents_per_road, i, lane[road], j + 1))
ir = in_range(j + 1, min_target_lane - .5, max_target_lane + .5)
leader_dists = leader_dists.at[j].multiply(ir)
return leader_dists, min_target_lane, max_target_lane
def body(carry):
lane, leader_dist = carry
min_target_lane = lane[road][i] - 1. * sigmoid2(lane[road][i] - 1., .5)
max_target_lane = lane[road][i] + 1. * (1. - sigmoid2(lane[road][i] - 3 + 1., .5)) # 3=number of lanes
leader_dists = jnp.zeros(3) # 3=number of lanes
carry = leader_dists, min_target_lane, max_target_lane
leader_dists, _, _ = lax.fori_loop(0, 3, inner_loop_body5, carry) # 3=number of lanes
max_leader_dist = max_pos(leader_dists, 3) # 3=number of lanes
furthest_ahead_leader_lane = select_x_where_y_equals_z(lane[road], pos[road], pos[road, i] + max_leader_dist, num_agents_per_road)
new_lane = select_x_if_larger_y_else_z(furthest_ahead_leader_lane, .5, lane[road, i])
clearance_gain = max_leader_dist - leader_dist
new_lane_thresh_factor = sigmoid(clearance_gain - min_clearance_gain)
new_lane_thresh_factor *= sigmoid(vel[road, i] - min_lane_change_vel)
new_lane = new_lane * new_lane_thresh_factor + lane[road, i] * (1. - new_lane_thresh_factor)
lane = lane.at[road, i].set(new_lane)
return lane, leader_dist
leader_dist = min_pos_delta_if_equal(pos[road], num_agents_per_road, i, lane[road], lane[road][i])
carry = lane, leader_dist
lane, leader_dist = lax.cond(step != 0 and step % lane_change_period == 0, body, lambda x: carry, carry)
leader_dist = min_pos_delta_if_equal(pos[road], num_agents_per_road, i, lane[road], lane[road, i]) # updated after lane change
leader_vel = select_x_where_y0_equals_z0_and_y1_equals_z1(vel[road], lane[road], lane[road][i], pos[road], pos[road][i] + leader_dist, num_agents_per_road)
light_dist = light_pos - pos[road, i]
light_ahead = sigmoid_with_slope(light_dist, 16.) # traffic_light_ahead_sigmoid_slope=16.
braking_at_light = curr_light_red * light_ahead * sigmoid_with_slope(leader_dist - light_dist, 16.) # traffic_sim_sigmoid_slope=16.
leader_dist = braking_at_light * light_dist + (1. - braking_at_light) * leader_dist
leader_vel = (1. - braking_at_light) * leader_vel
accel = idm(vel[road, i], leader_dist, leader_vel, 13.888, 2.) # 13.888=speed limit, max_accel=2.
vel = vel.at[road, i].add(accel * .1) # .1=tau=time step
vel = vel.at[road, i].multiply(sigmoid_with_slope(vel[road, i], traffic_sim_sigmoid_slope)) # allow velocities > 0 - eps only
progress = vel[road, i] * .1 # .1=tau=time step
pos = pos.at[road, i].add(progress)
sum_progress += progress
past_road_end = sigmoid_with_slope(pos[road,i] - road_end_pos, traffic_sim_sigmoid_slope)
pos = pos.at[road, i].add(-pos[road,i] * past_road_end)
return lane, pos, vel, num_agents_per_road, sum_progress, curr_light_red, road
@jit
def wrapper_inner_loop_body4(i, carry):
step, lane, pos, vel, num_agents_per_road, sum_progress, curr_light_red, road = carry
lane, pos, vel, num_agents_per_road, sum_progress, curr_light_red, road = inner_loop_body4(step, i, lane, pos, vel, num_agents_per_road, sum_progress, curr_light_red, road)
return step, lane, pos, vel, num_agents_per_road, sum_progress, curr_light_red, road
lax.fori_loop(0, num_agents_per_road - 3, wrapper_inner_loop_body4, (1, lane, pos, [[13.888, 13.888, 13.888, 13.888]], num_agents_per_road, 0.0, 0.63999973, road)) # 3 is number of lanes The first issue I was getting was |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 3 replies
-
You jit-compile You then pass leader_dist = min_pos_delta_if_equal(pos[road], num_agents_per_road, i, lane[road], lane[road][i]) At this point you get a new error because you're attempting to index a list with a traced value – you can fix that by casting the lists to a JAX array: pos = jnp.array([[40., 300., 340., 380.]])
lane = jnp.array([[40., 300., 340., 380.]]) Then you hit an Side-note: there's something strange in your Best of luck! |
Beta Was this translation helpful? Give feedback.
You jit-compile
wrapper_inner_loop_body4
, which means all arguments passed to it are non-static, including the first element ofcarry
, which isstep
.You then pass
step
as the first argument toinner_loop_body4
, which is marked as static, and this leads to an error. To fix that, you should avoid marking this argument as static. The reason this causes issues is because you're using Python'sand
operator with traced variables;and
attempts to eagerly convert outputs to static booleans (a behavior that is impossible to override). For that reason, JAX follows NumPy and uses the elemenwise and operator (&
) for this type of operation: