Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,12 @@ static int my_log(PyObject *dict, Log *log) {
assign_to_dict(dict, "max_observation_distance", log->max_observation_distance);
assign_to_dict(dict, "observation_coverage", log->observation_coverage);
assign_to_dict(dict, "partner_obs_coverage", log->partner_obs_coverage);
float at_fault_collision_pct =
(log->collisions_per_agent > 0) ? log->at_fault_collision_count / log->collisions_per_agent : 0.0f;
assign_to_dict(dict, "at_fault_collision_pct", at_fault_collision_pct);
float not_at_fault_collision_pct =
(log->collisions_per_agent > 0) ? log->not_at_fault_collision_count / log->collisions_per_agent : 0.0f;
Comment on lines +475 to +479
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The at-fault percentages are computed using log->collisions_per_agent as the denominator. Since you already track at_fault_collision_count and not_at_fault_collision_count, the most direct/consistent denominator is their sum; using collisions_per_agent risks drifting if the two counters ever differ from the collision counter and makes the intent less clear.

Suggested change
float at_fault_collision_pct =
(log->collisions_per_agent > 0) ? log->at_fault_collision_count / log->collisions_per_agent : 0.0f;
assign_to_dict(dict, "at_fault_collision_pct", at_fault_collision_pct);
float not_at_fault_collision_pct =
(log->collisions_per_agent > 0) ? log->not_at_fault_collision_count / log->collisions_per_agent : 0.0f;
float total_collision_count = log->at_fault_collision_count + log->not_at_fault_collision_count;
float at_fault_collision_pct =
(total_collision_count > 0) ? log->at_fault_collision_count / total_collision_count : 0.0f;
assign_to_dict(dict, "at_fault_collision_pct", at_fault_collision_pct);
float not_at_fault_collision_pct =
(total_collision_count > 0) ? log->not_at_fault_collision_count / total_collision_count : 0.0f;

Copilot uses AI. Check for mistakes.
assign_to_dict(dict, "not_at_fault_collision_pct", not_at_fault_collision_pct);
// assign_to_dict(dict, "avg_displacement_error", log->avg_displacement_error);
return 0;
}
22 changes: 22 additions & 0 deletions pufferlib/ocean/drive/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,27 @@ int unnormalize_traffic_light_state(int norm_state) {
}
}

/**
* @brief Categorizes the type of collision an agent is involved in.
*
* - STATIONARY_AGENT_COLLISION: ego hit a stationary partner agent
*
* - STATIONARY_EGO_COLLISION: ego was stationary when hit by another agent
*
* - ACTIVE_FRONT_COLLISION: ego collided with something ahead of it
*
* - ACTIVE_REAR_COLLISION: ego was hit from behind while moving
*
* - ACTIVE_LATERAL_COLLISION: ego was involved in a side collision while moving
*/
typedef enum {
STATIONARY_AGENT_COLLISION,
STATIONARY_EGO_COLLISION,
ACTIVE_FRONT_COLLISION,
ACTIVE_REAR_COLLISION,
ACTIVE_LATERAL_COLLISION,
} CollisionType;
Comment on lines +142 to +161
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The enum doc comment references STATIONARY_PARTNER_COLLISION, but the actual enum value is named STATIONARY_AGENT_COLLISION. Please align the documentation and the enum name to avoid confusion for downstream users.

Copilot uses AI. Check for mistakes.

struct Agent {
int id;
int type;
Expand Down Expand Up @@ -195,6 +216,7 @@ struct Agent {
float goals_attempted_this_episode; // goals reached + last goal(if this segment can be judged as an attempt)
int current_goal_reached;
int collided_before_goal;
CollisionType current_collision_type;
float init_goal_x;
float init_goal_y;
float init_goal_z;
Expand Down
80 changes: 73 additions & 7 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
#define LANE_DISTANCE_NORMALIZATION 4.0f
#define LANE_SWITCH_THRESHOLD 0.05f // Hysteresis: new lane must be 5% better to switch
#define LANE_ALIGN_COS_THRESHOLD 0.5f
#define ZERO_THRESHOLD 1e-6f

// Minimum distance to goal position
#define MIN_DISTANCE_TO_GOAL 2.0f
Expand Down Expand Up @@ -186,10 +187,12 @@
// Offsets
#define COLLISION_RANGE 5
#define Z_RANGE 3
#define Z_BUFFER 4.0f // 4.0m buffer for different z level checking
#define SPEED_LIMIT 20.0f // Hardcoded speed limit value
#define COMFORT_JERK_THRESHOLD 5.0f // For JERK model comfort
#define COMFORT_ACCEL_THRESHOLD 3.0f // For JERK and CLASSIC model comfort
#define Z_BUFFER 4.0f // 4.0m buffer for different z level checking
#define SPEED_LIMIT 20.0f // Hardcoded speed limit value
#define COMFORT_JERK_THRESHOLD 5.0f // For JERK model comfort
#define COMFORT_ACCEL_THRESHOLD 3.0f // For JERK and CLASSIC model comfort
#define STATIONARY_SPEED_THRESHOLD 0.05f // Speed below which we consider the agent stationary
#define HEAD_ON_COLLISION_ANGLE_THRESHOLD 30.0f // Degrees within which a collision is considered head-on

// Metrics Heuristics
#define MIN_GOAL_SEGMENT_TIME_TO_ANALYZE_AGENT 1.0f
Expand Down Expand Up @@ -290,9 +293,11 @@ struct Log {
float total_distance_travelled;
float total_infractions;
float avg_speed_per_agent;
float max_observation_distance; // average max observation distance
float observation_coverage; // percentage of entities in obs window seen on average
float partner_obs_coverage; // % of partners within radius that fit in the obs slots
float max_observation_distance; // average max observation distance
float observation_coverage; // percentage of entities in obs window seen on average
float partner_obs_coverage; // % of partners within radius that fit in the obs slots
float at_fault_collision_count; // count of at-fault collisions (ratio computed in my_log)
float not_at_fault_collision_count; // count of not at-fault collisions (ratio computed in my_log)
};

typedef struct GridMapEntity GridMapEntity;
Expand Down Expand Up @@ -1563,6 +1568,56 @@ int collision_check(Drive *env, int agent_idx) {
return car_collided_with_index;
}

float dot_product(float v1_x, float v1_y, float v2_x, float v2_y) { return v1_x * v2_x + v1_y * v2_y; }

float get_partner_relative_angle(Drive *env, int ego_agent_idx, int partner_agent_idx) {
Agent *ego = &env->agents[ego_agent_idx];
Agent *partner = &env->agents[partner_agent_idx];

float dx = partner->sim_x - ego->sim_x;
float dy = partner->sim_y - ego->sim_y;
float norm = sqrtf(dx * dx + dy * dy);

if (norm < ZERO_THRESHOLD) {
return 0.0f; // Agents are at the same position, treat as head-on
}

float ego_x = cosf(ego->sim_heading);
float ego_y = sinf(ego->sim_heading);

float relative_angle = acosf(dot_product(dx, dy, ego_x, ego_y) / norm);

Comment thread
mpragnay marked this conversation as resolved.
return relative_angle;
}

void classify_collision_type(Drive *env, int ego_agent_idx, int collided_with_idx) {
Agent *ego = &env->agents[ego_agent_idx];
Agent *collided_entity = &env->agents[collided_with_idx];

float ego_speed_magnitude = sqrtf(ego->sim_vx * ego->sim_vx + ego->sim_vy * ego->sim_vy);
float partner_speed_magnitude =
sqrtf(collided_entity->sim_vx * collided_entity->sim_vx + collided_entity->sim_vy * collided_entity->sim_vy);

// Stationary cases
if (ego_speed_magnitude < STATIONARY_SPEED_THRESHOLD) {
ego->current_collision_type = STATIONARY_EGO_COLLISION;
return;
} else if (partner_speed_magnitude < STATIONARY_SPEED_THRESHOLD) {
ego->current_collision_type = STATIONARY_AGENT_COLLISION;
return;
}

// Both agents are moving
float relative_angle = get_partner_relative_angle(env, ego_agent_idx, collided_with_idx);
if (relative_angle < HEAD_ON_COLLISION_ANGLE_THRESHOLD * (M_PI / 180.0f)) {
ego->current_collision_type = ACTIVE_FRONT_COLLISION;
} else if (relative_angle > (180.0f - HEAD_ON_COLLISION_ANGLE_THRESHOLD) * (M_PI / 180.0f)) {
ego->current_collision_type = ACTIVE_REAR_COLLISION;
} else {
ego->current_collision_type = ACTIVE_LATERAL_COLLISION;
}
}

bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]) {
if (fmax(p1[0], p2[0]) < fmin(q1[0], q2[0]) || fmin(p1[0], p2[0]) > fmax(q1[0], q2[0]) ||
fmax(p1[1], p2[1]) < fmin(q1[1], q2[1]) || fmin(p1[1], p2[1]) > fmax(q1[1], q2[1]))
Expand Down Expand Up @@ -2656,6 +2711,17 @@ void compute_agent_metrics(Drive *env, int agent_idx) {
car_collided_with_index = collision_check(env, agent_idx);
if (car_collided_with_index != -1) {
collided = VEHICLE_COLLISION;
agent->collision_state = collided;

classify_collision_type(env, agent_idx, car_collided_with_index);

// Determine at-fault
CollisionType ct = env->agents[agent_idx].current_collision_type;
int at_fault =
(ct == STATIONARY_AGENT_COLLISION || ct == ACTIVE_FRONT_COLLISION || ct == ACTIVE_LATERAL_COLLISION);
int not_fault_ct = (ct == ACTIVE_REAR_COLLISION || ct == STATIONARY_EGO_COLLISION);
env->log.at_fault_collision_count += at_fault;
env->log.not_at_fault_collision_count += not_fault_ct;
}
agent->collision_state = collided;

Expand Down
Loading