Hello, I am working on converting some highway simulation code from C++ using the Adept library to Python using JAX's automatic differentiation. I believe I made all the necessary changes to get value_and_grad to work, yet the gradient is returning 0. I don't think there are any issues in the code, because it returns the expected output value, but the correct gradient isn't being returned. I tried to make the code as self-contained as possible without losing its functionality. Any advice on what might be happening here, or on how to go about debugging JAX when grad returns 0 even though the output relies on the given input? Thanks!

I know that a zero gradient means that the output doesn't depend on the given input, but it should: the same code produced a gradient when it was implemented in C++ using Adept.

Code:

import jax.numpy as jnp
import jax
from jax import value_and_grad
# again, this only works on startup!
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_enable_x64", True)
def sigmoid(x):
    sigmoid_slope = 32.
    return 1.0 / (1.0 + jnp.exp(-sigmoid_slope * x))
def sigmoid2(x, x_offset):
    sigmoid_slope = 32.
    return 1. / (1. + jnp.exp(-sigmoid_slope * (x - x_offset)))
def in_range_unscaled(x, from_, to):
    return sigmoid2(x, from_)*(1. - sigmoid2(x, to))
def in_range(x, from_, to):
    return in_range_unscaled(x, from_, to) / in_range_unscaled((from_ + to) / 2.0, from_, to) # second term just for scaling
def sigmoid_with_slope(x, slope):
    return 1.0 / (1.0 + jnp.exp(-slope * x))
def periodic_step(x, period, offset, slope):
    return sigmoid_with_slope(jnp.sin((x - offset) * 3.141593 / period), slope)
def light_red(x, overall_period, on_period, offset, slope = -1.0):
    if slope < 0.0:
        slope = 32.
    return periodic_step(x, on_period, offset, slope)
def min_pos_delta_if_equal(x, num_elems, ref_idx, y, z):
    sum = 0.0
    M = 1000.
    for i in range(num_elems):
        if i == ref_idx: continue
        diff = x[i] - x[ref_idx]
        neg_offset = 2 * M * (1.0 - sigmoid(diff))
        diff += neg_offset
        lane_shift = 1000. * (1.0 - in_range(y[i] - z, -0.5, 0.5))
        diff += lane_shift
        exp_arg = -diff
        sum += jnp.exp(exp_arg)
    return -jnp.log(sum)
def max_pos(x, num_elems):
    sum = 0.0
    for i in range(num_elems):
        sum += jnp.exp(x[i])
    return jnp.log(sum)
def select_x_where_y_equals_z(x, y, z, num_elems):
    sum = 0.0
    range_scale = 1.0
    for i in range(num_elems):
        factor = in_range(y[i] - z, -range_scale, range_scale)
        sum += x[i] * factor
    return sum
def select_x_where_y0_equals_z0_and_y1_equals_z1(x, y0, z0, y1, z1, num_elems):
    sum = 0.0
    for i in range(num_elems):
        ir0 = in_range(y0[i] - z0, -1.0, 1.0)
        ir1 = in_range(y1[i] - z1, -0.1, 0.1)
        factor = ir0 * ir1
        sum += x[i] * factor
    return sum
def select_x_if_larger_y_else_z(x, y, z):
    return sigmoid2(x, y) * x + (1.0 - sigmoid2(x, y)) * z
def ipow(x, p):
    r = 1.0
    for i in range(p):
        r *= x
    return r
def idm(vel, leader_dist, leader_vel, speed_limit, max_accel):
    s0 = 5.0 # min spacing
    t = 3.0; # desired time headway
    accel_term0 = ipow(vel / speed_limit, 4)
    accel_term1 = ipow((s0 + vel * t + (vel * (vel - leader_vel) / (2 * max_accel))) / leader_dist, 2)
    return max_accel * (1.0 - accel_term0 - accel_term1)
def traffic_sim_smoothed(x, num_agents_per_road):
    # Defining locals
    dummy_pos = 300.
    init_pos_inc = 40
    speed_limit = 13.888
    road_end_pos = 250.
    max_ts = 10.
    num_lanes = 3
    num_roads = 1
    tau = .1
    light_period = 10.
    light_red_duration = 5.
    min_clearance_gain = 10.
    min_lane_change_vel = 8.
    max_accel = 2.
    lane_change_period = 25
    traffic_sim_sigmoid_slope = 16.0
    light_pos = 100.
    traffic_light_ahead_sigmoid_slope = 16.0
    # End definitions
    num_agents_per_road += num_lanes
    pos = jnp.zeros((num_roads, num_agents_per_road))
    vel = jnp.zeros((num_roads, num_agents_per_road))
    lane = jnp.zeros((num_roads, num_agents_per_road))
    for road in range(num_roads):
        i = 1
        while i <= num_lanes: # for (int i = 1; i <= num_lanes; i++) {
            idx = num_agents_per_road - num_lanes + i - 1
            lane = lane.at[road, idx].set(1 + (i - 1) % num_lanes)
            pos = pos.at[road, idx].set(dummy_pos + (i - 1) * init_pos_inc)
            vel = vel.at[road, idx].set(speed_limit)
            i += 1
        for i in range(num_agents_per_road - num_lanes):
            pos = pos.at[road, i].set(((i + 1) * init_pos_inc) % int(road_end_pos))
            vel = vel.at[road, i].set(speed_limit)
            lane = lane.at[road, i].set(1. + i % num_lanes)
    sum_progress = 0.0
    light_offset = x
    step = 0
    while tau * step <= max_ts:
        curr_light_red = light_red(tau * step, light_period, light_red_duration, light_offset, traffic_sim_sigmoid_slope)
        for road in range(num_roads):
            for i in range(num_agents_per_road - num_lanes):
                leader_dist = min_pos_delta_if_equal(pos[road], num_agents_per_road, i, lane[road], lane[road][i])
                if step != 0 and step % lane_change_period == 0:
                    min_target_lane = lane[road][i] - 1. * sigmoid2(lane[road][i] - 1., .5)
                    max_target_lane = lane[road][i] + 1. * (1. - sigmoid2(lane[road][i] - num_lanes + 1., .5))
                    # get vel of leader on each lane,
                    # then get vel of furthest ahead of those,
                    # then get corresponding lane
                    leader_dists = jnp.zeros(num_lanes)
                    for j in range(num_lanes):
                        leader_dists = leader_dists.at[j].set(min_pos_delta_if_equal(pos[road], num_agents_per_road, i, lane[road], j + 1))
                        ir = in_range(j + 1, min_target_lane - .5, max_target_lane + .5)
                        leader_dists = leader_dists.at[j].multiply(ir)
                    max_leader_dist = max_pos(leader_dists, num_lanes)
                    furthest_ahead_leader_lane = select_x_where_y_equals_z(lane[road], pos[road], pos[road, i] + max_leader_dist, num_agents_per_road)
                    new_lane = select_x_if_larger_y_else_z(furthest_ahead_leader_lane, .5, lane[road, i])
                    clearance_gain = max_leader_dist - leader_dist
                    new_lane_thresh_factor = sigmoid(clearance_gain - min_clearance_gain)
                    new_lane_thresh_factor *= sigmoid(vel[road, i] - min_lane_change_vel)
                    new_lane = new_lane * new_lane_thresh_factor + lane[road, i] * (1. - new_lane_thresh_factor)
                    lane = lane.at[road, i].set(new_lane)
                    leader_dist = min_pos_delta_if_equal(pos[road], num_agents_per_road, i, lane[road], lane[road, i]) # updated after lane change
                leader_vel = select_x_where_y0_equals_z0_and_y1_equals_z1(vel[road], lane[road], lane[road][i], pos[road], pos[road][i] + leader_dist, num_agents_per_road)
                light_dist = light_pos - pos[road, i]
                light_ahead = sigmoid_with_slope(light_dist, traffic_light_ahead_sigmoid_slope)
                braking_at_light = curr_light_red * light_ahead * sigmoid_with_slope(leader_dist - light_dist, traffic_sim_sigmoid_slope)
                leader_dist = braking_at_light * light_dist + (1. - braking_at_light) * leader_dist
                leader_vel = (1. - braking_at_light) * leader_vel
                accel = idm(vel[road, i], leader_dist, leader_vel, speed_limit, max_accel)
                vel = vel.at[road, i].add(accel.primal * tau)
                vel = vel.at[road, i].multiply(sigmoid_with_slope(vel[road, i], traffic_sim_sigmoid_slope)) # allow velocities > 0 - eps only
                progress = vel[road, i] * tau
                pos = pos.at[road, i].add(progress)
                sum_progress += progress
                past_road_end = sigmoid_with_slope(pos[road,i] - road_end_pos, traffic_sim_sigmoid_slope)
                pos = pos.at[road, i].add(-pos[road,i] * past_road_end)
        step += 1
    return sum_progress
if __name__ == "__main__":
    num_agents_per_road = 1
    tau = .1
    x_init = tau/4 # Initial x time step
    print("x, y, dydx")
    value, gradient = value_and_grad(traffic_sim_smoothed)(x_init, num_agents_per_road)
    print(str(x_init) + ", ", value, ", " + str(gradient))

Edited:

from operator import ne, neg
import jax
import jax.numpy as jnp
from jax import value_and_grad
from jax import grad, jit, vmap, pmap, lax
from jax.scipy.special import logsumexp
# again, this only works on startup!
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_enable_x64", True)
@jit
def sigmoid(x):
    sigmoid_slope = 32.
    return 1.0 / (1.0 + jnp.exp(-sigmoid_slope * x))
@jit
def sigmoid2(x, x_offset):
    sigmoid_slope = 32.
    return 1. / (1. + jnp.exp(-sigmoid_slope * (x - x_offset)))
@jit
def in_range_unscaled(x, from_, to):
    return sigmoid2(x, from_)*(1. - sigmoid2(x, to))
@jit
def in_range(x, from_, to):
    return in_range_unscaled(x, from_, to) / in_range_unscaled((from_ + to) / 2.0, from_, to) # second term just for scaling
@jit
def sigmoid_with_slope(x, slope):
    return 1.0 / (1.0 + jnp.exp(-slope * x))
@jit
def periodic_step(x, period, offset, slope):
    return sigmoid_with_slope(jnp.sin((x - offset) * jnp.pi / period), slope)
def light_red(x, on_period, offset, slope = -1.0):
    if slope < 0.0:
        slope = 32.
    return periodic_step(x, on_period, offset, slope)
def min_pos_delta_if_equal(x, num_elems, ref_idx, y, z):
    sum = 0.0
    M = 1000.
    # def cond(pred, true_fun, operand):
    #     if pred:
    #         return body_fun(operand)
    # def body_fun(i, sum):
    #     if i != ref_idx:
    #         diff = x[i] - x[ref_idx]
    #         neg_offset = 2 * M * (1. - sigmoid(diff))
    #         diff += neg_offset
    #         lane_shift = 1000. * (1.0 - in_range(y[i] - z, -0.5, 0.5))
    #         diff += lane_shift
    #         exp_arg = -diff
    #         sum += jnp.exp(exp_arg)
    #     return sum
    # sum = lax.fori_loop(0, num_elems, body_fun, 0)
    # return -jnp.log(sum)
    for i in range(num_elems):
        if i == ref_idx: continue
        diff = x[i] - x[ref_idx]
        neg_offset = 2 * M * (1.0 - sigmoid(diff))
        diff += neg_offset
        lane_shift = 1000. * (1.0 - in_range(y[i] - z, -0.5, 0.5))
        diff += lane_shift
        exp_arg = -diff
        sum += jnp.exp(exp_arg)
    return -jnp.log(sum)
    # return -jnp.logsumexp(x[:num_elems])
def max_pos(x, num_elems):
    return logsumexp(x[:num_elems])
def select_x_where_y_equals_z(x, y, z, num_elems):
    range_scale = 1.0
    return (x[:num_elems] * in_range(y[:num_elems] - z, -range_scale, range_scale)).sum()
def select_x_where_y0_equals_z0_and_y1_equals_z1(x, y0, z0, y1, z1, num_elems):
    return (x[:num_elems] * in_range(y0[:num_elems] - z0, -1.0, 1.0) * in_range(y1[:num_elems] - z1, -0.1, 0.1)).sum()
@jit
def select_x_if_larger_y_else_z(x, y, z):
    return sigmoid2(x, y) * x + (1.0 - sigmoid2(x, y)) * z
@jit
def idm(vel, leader_dist, leader_vel, speed_limit, max_accel):
    s0 = 5.0 # min spacing
    t = 3.0; # desired time headway
    accel_term0 = jnp.power(vel / speed_limit, 4)
    accel_term1 = jnp.power((s0 + vel * t + (vel * (vel - leader_vel) / (2 * max_accel))) / leader_dist, 2)
    return max_accel * (1.0 - accel_term0 - accel_term1)
def traffic_sim_smoothed(x, num_agents_per_road):
    # Defining locals
    speed_limit = 13.888
    max_ts = 10.
    num_lanes = 3
    num_roads = 1
    tau = .1
    light_red_duration = 5.
    max_accel = 2.
    traffic_sim_sigmoid_slope = 16.0
    light_pos = 100.
    traffic_light_ahead_sigmoid_slope = 16.0
    # End definitions
    num_agents_per_road += num_lanes
    pos = jnp.zeros((num_roads, num_agents_per_road))
    vel = jnp.zeros((num_roads, num_agents_per_road))
    lane = jnp.zeros((num_roads, num_agents_per_road))
    sum_progress = 0.0
    light_offset = x
    step = 0
    while tau * step <= max_ts:
        curr_light_red = light_red(tau * step, light_red_duration, light_offset, traffic_sim_sigmoid_slope)
        for road in range(num_roads):
            for i in range(num_agents_per_road - num_lanes):
                leader_dist = min_pos_delta_if_equal(pos[road], num_agents_per_road, i, lane[road], lane[road][i])
                leader_vel = select_x_where_y0_equals_z0_and_y1_equals_z1(vel[road], lane[road], lane[road][i], pos[road], pos[road][i] + leader_dist, num_agents_per_road)
                light_dist = light_pos - pos[road, i]
                light_ahead = sigmoid_with_slope(light_dist, traffic_light_ahead_sigmoid_slope)
                braking_at_light = curr_light_red * light_ahead * sigmoid_with_slope(leader_dist - light_dist, traffic_sim_sigmoid_slope)
                leader_dist = braking_at_light * light_dist + (1. - braking_at_light) * leader_dist
                accel = idm(vel[road, i], leader_dist, leader_vel, speed_limit, max_accel)
                vel = vel.at[road, i].add(accel.primal * tau)
                progress = vel[road, i] * tau
                sum_progress += progress
        step += 1
    return sum_progress
if __name__ == "__main__":
    num_agents_per_road = 1
    tau = .1
    x_init = tau/4 # Initial x time step
    print("x, y, dydx")
    value, gradient = value_and_grad(traffic_sim_smoothed)(x_init, num_agents_per_road)
    print(str(x_init) + ", ", value, ", " + str(gradient))
So there's a lot in your code and I don't think I'll have a chance to debug it in detail, but one red flag to me is your reliance on Python loops. Particularly this one:

def max_pos(x, num_elems):
    sum = 0.0
    for i in range(num_elems):
        sum += jnp.exp(x[i])
    return jnp.log(sum)

Due to accumulation of floating-point roundoff error, I wouldn't be surprised if this is what is causing your zero gradient. In general, if you are using JAX (or NumPy) and find yourself looping over array values, you will get better and faster results by using built-in array-oriented functions. So, for example, you could replace this function with this:

def max_pos(x, num_elems):
    return jnp.log(jnp.exp(x[:num_elems]).sum())

But even then, exponentiating an array, summing, and then taking a log is prone to floating-point overflow and roundoff problems, and for this reason JAX provides the logsumexp function:

from jax.scipy.special import logsumexp

def max_pos(x, num_elems):
    return logsumexp(x[:num_elems])

I would suggest going through your code and making this kind of change wherever you are looping over the contents of an array. You may find that the zero gradient issue goes away once you do, due to the improved numerical accuracy of vector-based functions and their autodiff rules. As a bonus, you'll end up with much more performant code as well.
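A quick way to see the numerical point made above is to compare the three `max_pos` variants on a well-behaved input and on one whose exponentials overflow. This sketch assumes JAX's default float32 precision (no `jax_enable_x64`), and `max_pos_loop` and `max_pos_naive_vec` are just labels introduced here for the first two versions; it is not code from the thread.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def max_pos_loop(x, num_elems):
    # Python loop from the question: accumulate exp(x[i]), then take the log
    total = 0.0
    for i in range(num_elems):
        total += jnp.exp(x[i])
    return jnp.log(total)


def max_pos_naive_vec(x, num_elems):
    # Vectorized, but still exponentiates the raw values before summing
    return jnp.log(jnp.exp(x[:num_elems]).sum())


def max_pos(x, num_elems):
    # logsumexp subtracts the max before exponentiating, so exp() cannot overflow
    return logsumexp(x[:num_elems])


small = jnp.array([1.0, 2.0, 3.0])
big = jnp.array([1000.0, 1000.0])

print(max_pos_loop(small, 3), max_pos(small, 3))        # both ≈ 3.4076
print(max_pos_loop(big, 2), max_pos_naive_vec(big, 2))  # inf inf: exp(1000.) overflows
print(max_pos(big, 2))                                  # ≈ 1000.6931 (1000 + log 2)
print(jax.grad(max_pos)(big, 2))                        # [0.5 0.5], the softmax of x
```

On the overflowing input the loop and the naive vectorized version fail in the same way, so it is the max-shifting inside `logsumexp`, rather than vectorization alone, that keeps the value finite; its gradient is the softmax of the inputs.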
Could you paste your code in the question rather than linking to a zip file? Thanks!
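Applying the answer's advice to the asker's `min_pos_delta_if_equal` (where the commented-out `lax.fori_loop` attempt shows the loop was a sticking point), one possible vectorized version computes every `diff` at once and replaces `-log(sum(exp(-diff)))` with `-logsumexp(-diff)`. This is a sketch under assumptions: the helpers are simplified restatements of the question's `sigmoid`/`in_range`, `min_pos_delta_if_equal_vec` is a hypothetical name, and excluding the reference agent by setting its term to `-inf` is a choice made here, not something from the thread.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

SLOPE = 32.0  # same sigmoid slope as the question's helpers


def sigmoid(x):
    return 1.0 / (1.0 + jnp.exp(-SLOPE * x))


def in_range(x, from_, to):
    # Smooth indicator of from_ < x < to, scaled to 1 at the midpoint
    def unscaled(v):
        return sigmoid(v - from_) * (1.0 - sigmoid(v - to))
    return unscaled(x) / unscaled((from_ + to) / 2.0)


def min_pos_delta_if_equal_loop(x, num_elems, ref_idx, y, z):
    # Same computation as the question's loop version
    total = 0.0
    M = 1000.0
    for i in range(num_elems):
        if i == ref_idx:
            continue
        diff = x[i] - x[ref_idx]
        diff += 2 * M * (1.0 - sigmoid(diff))
        diff += 1000.0 * (1.0 - in_range(y[i] - z, -0.5, 0.5))
        total += jnp.exp(-diff)
    return -jnp.log(total)


def min_pos_delta_if_equal_vec(x, num_elems, ref_idx, y, z):
    # Hypothetical vectorized version: all diffs are computed in one shot
    M = 1000.0
    x, y = x[:num_elems], y[:num_elems]
    diff = x - x[ref_idx]
    diff = diff + 2 * M * (1.0 - sigmoid(diff))
    diff = diff + 1000.0 * (1.0 - in_range(y - z, -0.5, 0.5))
    # Exclude the reference agent by giving it -inf, since exp(-inf) = 0
    not_ref = jnp.arange(num_elems) != ref_idx
    return -logsumexp(jnp.where(not_ref, -diff, -jnp.inf))


x = jnp.array([0.0, 5.0, 12.0])  # positions
y = jnp.array([1.0, 1.0, 1.0])   # lanes
print(min_pos_delta_if_equal_loop(x, 3, 0, y, 1.0))
print(min_pos_delta_if_equal_vec(x, 3, 0, y, 1.0))
```

On these inputs both versions return about 4.9991. They differ when every other agent is far behind or in another lane: then `exp(-diff)` underflows to 0 for every term and the loop returns `-log(0) = inf`, while the `logsumexp` form stays finite.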
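On the question of how to debug a zero gradient when the output clearly does depend on the input: one general approach is to compare `jax.grad` against a central finite difference, first for the whole function and then for intermediate quantities, until the two disagree. The sketch below is a toy stand-in, not the simulation itself; `accel_fn`, `step_cut`, `step_ok`, and `central_diff` are hypothetical names. It uses the public `jax.lax.stop_gradient` to mimic what the `accel.primal` reads in `traffic_sim_smoothed` appear to do. As far as I can tell, `.primal` is an internal attribute of JAX's tracers (not public API), and reading it keeps the value but drops the derivative, so the velocity update would never see how `accel` depends on `x`. That alone would be enough to make the gradient exactly zero, so it is worth checking before anything else.

```python
import jax
import jax.numpy as jnp


def accel_fn(offset):
    # Toy stand-in for the chain: light offset -> braking -> idm acceleration
    return 2.0 * jnp.sin(offset)


def step_cut(offset):
    accel = accel_fn(offset)
    # stop_gradient keeps the value but treats it as a constant for autodiff,
    # mimicking a read of `accel.primal` before updating the velocity
    return jax.lax.stop_gradient(accel) * 0.1


def step_ok(offset):
    return accel_fn(offset) * 0.1


def central_diff(f, x, h=1e-3):
    # Finite-difference estimate of df/dx, independent of JAX's autodiff
    return (f(x + h) - f(x - h)) / (2 * h)


x0 = 0.5
for f in (accel_fn, step_cut, step_ok):
    # A mismatch between autodiff and finite differences marks where the derivative is lost
    print(f.__name__, jax.grad(f)(x0), central_diff(f, x0))
```

If the finite difference is clearly nonzero while `jax.grad` returns exactly 0, the dependency is being cut somewhere rather than the function truly being flat; checking intermediates one by one (here `accel_fn` agrees and `step_cut` does not) narrows down where.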