Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,12 @@ static int my_log(PyObject *dict, Log *log) {
assign_to_dict(dict, "max_observation_distance", log->max_observation_distance);
assign_to_dict(dict, "observation_coverage", log->observation_coverage);
assign_to_dict(dict, "partner_obs_coverage", log->partner_obs_coverage);
float at_fault_collision_pct =
(log->collisions_per_agent > 0) ? log->at_fault_collision_count / log->collisions_per_agent : 0.0f;
assign_to_dict(dict, "at_fault_collision_pct", at_fault_collision_pct);
float not_at_fault_collision_pct =
(log->collisions_per_agent > 0) ? log->not_at_fault_collision_count / log->collisions_per_agent : 0.0f;
Comment on lines +475 to +479
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The at-fault percentages are computed using log->collisions_per_agent as the denominator. Since you already track at_fault_collision_count and not_at_fault_collision_count, the most direct/consistent denominator is their sum; using collisions_per_agent risks drifting if the two counters ever differ from the collision counter and makes the intent less clear.

Suggested change
float at_fault_collision_pct =
(log->collisions_per_agent > 0) ? log->at_fault_collision_count / log->collisions_per_agent : 0.0f;
assign_to_dict(dict, "at_fault_collision_pct", at_fault_collision_pct);
float not_at_fault_collision_pct =
(log->collisions_per_agent > 0) ? log->not_at_fault_collision_count / log->collisions_per_agent : 0.0f;
float total_collision_count = log->at_fault_collision_count + log->not_at_fault_collision_count;
float at_fault_collision_pct =
(total_collision_count > 0) ? log->at_fault_collision_count / total_collision_count : 0.0f;
assign_to_dict(dict, "at_fault_collision_pct", at_fault_collision_pct);
float not_at_fault_collision_pct =
(total_collision_count > 0) ? log->not_at_fault_collision_count / total_collision_count : 0.0f;

Copilot uses AI. Check for mistakes.
assign_to_dict(dict, "not_at_fault_collision_pct", not_at_fault_collision_pct);
// assign_to_dict(dict, "avg_displacement_error", log->avg_displacement_error);
return 0;
}
22 changes: 22 additions & 0 deletions pufferlib/ocean/drive/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,27 @@ int unnormalize_traffic_light_state(int norm_state) {
}
}

/**
* @brief Categorizes the type of collision an agent is involved in.
*
* - STATIONARY_PARTNER_COLLISION: ego hit a stationary partner agent
*
* - STATIONARY_EGO_COLLISION: ego was stationary when hit by another agent
*
* - ACTIVE_FRONT_COLLISION: ego collided with something ahead of it
*
* - ACTIVE_REAR_COLLISION: ego was hit from behind while moving
*
* - ACTIVE_LATERAL_COLLISION: ego was involved in a side collision while moving
*/
typedef enum {
STATIONARY_AGENT_COLLISION,
STATIONARY_EGO_COLLISION,
ACTIVE_FRONT_COLLISION,
ACTIVE_REAR_COLLISION,
ACTIVE_LATERAL_COLLISION,
} CollisionType;
Comment on lines +142 to +161
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The enum doc comment references STATIONARY_PARTNER_COLLISION, but the actual enum value is named STATIONARY_AGENT_COLLISION. Please align the documentation and the enum name to avoid confusion for downstream users.

Copilot uses AI. Check for mistakes.

struct Agent {
int id;
int type;
Expand Down Expand Up @@ -195,6 +216,7 @@ struct Agent {
float goals_attempted_this_episode; // goals reached + last goal(if this segment can be judged as an attempt)
int current_goal_reached;
int collided_before_goal;
CollisionType current_collision_type;
float init_goal_x;
float init_goal_y;
float init_goal_z;
Expand Down
78 changes: 69 additions & 9 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,12 @@
// Offsets
#define COLLISION_RANGE 5
#define Z_RANGE 3
#define Z_BUFFER 4.0f // 4.0m buffer for different z level checking
#define SPEED_LIMIT 20.0f // Hardcoded speed limit value
#define COMFORT_JERK_THRESHOLD 5.0f // For JERK model comfort
#define COMFORT_ACCEL_THRESHOLD 3.0f // For JERK and CLASSIC model comfort
#define Z_BUFFER 4.0f // 4.0m buffer for different z level checking
#define SPEED_LIMIT 20.0f // Hardcoded speed limit value
#define COMFORT_JERK_THRESHOLD 5.0f // For JERK model comfort
#define COMFORT_ACCEL_THRESHOLD 3.0f // For JERK and CLASSIC model comfort
#define STATIONARY_SPEED_THRESHOLD 0.05f // Speed below which we consider the agent stationary
#define HEAD_ON_COLLISION_ANGLE_THRESHOLD 30.0f // Degrees within which a collision is considered head-on

// Metrics Heuristics
#define MIN_GOAL_SEGMENT_TIME_TO_ANALYZE_AGENT 1.0f
Expand Down Expand Up @@ -241,9 +243,11 @@ struct Log {
float total_distance_travelled;
float total_infractions;
float avg_speed_per_agent;
float max_observation_distance; // average max observation distance
float observation_coverage; // percentage of entities in obs window seen on average
float partner_obs_coverage; // % of partners within radius that fit in the obs slots
float max_observation_distance; // average max observation distance
float observation_coverage; // percentage of entities in obs window seen on average
float partner_obs_coverage; // % of partners within radius that fit in the obs slots
float at_fault_collision_count; // count of at-fault collisions (ratio computed in my_log)
float not_at_fault_collision_count; // count of not at-fault collisions (ratio computed in my_log)
};

typedef struct GridMapEntity GridMapEntity;
Expand Down Expand Up @@ -1504,6 +1508,52 @@ int collision_check(Drive *env, int agent_idx) {
return car_collided_with_index;
}

float dot_product(float v1_x, float v1_y, float v2_x, float v2_y) { return v1_x * v2_x + v1_y * v2_y; }

float get_partner_relative_angle(Drive *env, int ego_agent_idx, int partner_agent_idx) {
Agent *ego = &env->agents[ego_agent_idx];
Agent *partner = &env->agents[partner_agent_idx];

float dx = partner->sim_x - ego->sim_x;
float dy = partner->sim_y - ego->sim_y;
float norm = sqrtf(dx * dx + dy * dy);

float ego_x = cosf(ego->sim_heading);
float ego_y = sinf(ego->sim_heading);

float relative_angle = acosf(dot_product(dx, dy, ego_x, ego_y) / norm);

Comment thread
mpragnay marked this conversation as resolved.
return relative_angle;
}

void classify_collision_type(Drive *env, int ego_agent_idx, int collided_with_idx) {
Agent *ego = &env->agents[ego_agent_idx];
Agent *collided_entity = &env->agents[collided_with_idx];

float ego_speed_magnitude = sqrtf(ego->sim_vx * ego->sim_vx + ego->sim_vy * ego->sim_vy);
float partner_speed_magnitude =
sqrtf(collided_entity->sim_vx * collided_entity->sim_vx + collided_entity->sim_vy * collided_entity->sim_vy);

// Stationary cases
if (ego_speed_magnitude < STATIONARY_SPEED_THRESHOLD) {
ego->current_collision_type = STATIONARY_EGO_COLLISION;
return;
} else if (partner_speed_magnitude < STATIONARY_SPEED_THRESHOLD) {
ego->current_collision_type = STATIONARY_AGENT_COLLISION;
return;
}

// Both agents are moving
float relative_angle = get_partner_relative_angle(env, ego_agent_idx, collided_with_idx);
if (relative_angle < HEAD_ON_COLLISION_ANGLE_THRESHOLD * (M_PI / 180.0f)) {
ego->current_collision_type = ACTIVE_FRONT_COLLISION;
} else if (relative_angle > (180.0f - HEAD_ON_COLLISION_ANGLE_THRESHOLD) * (M_PI / 180.0f)) {
ego->current_collision_type = ACTIVE_REAR_COLLISION;
} else {
ego->current_collision_type = ACTIVE_LATERAL_COLLISION;
}
}

bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]) {
if (fmax(p1[0], p2[0]) < fmin(q1[0], q2[0]) || fmin(p1[0], p2[0]) > fmax(q1[0], q2[0]) ||
fmax(p1[1], p2[1]) < fmin(q1[1], q2[1]) || fmin(p1[1], p2[1]) > fmax(q1[1], q2[1]))
Expand Down Expand Up @@ -2588,10 +2638,20 @@ void compute_agent_metrics(Drive *env, int agent_idx) {

// Check for vehicle collisions
int car_collided_with_index = collision_check(env, agent_idx);
if (car_collided_with_index != -1)
if (car_collided_with_index != -1) {
collided = VEHICLE_COLLISION;
agent->collision_state = collided;

agent->collision_state = collided;
classify_collision_type(env, agent_idx, car_collided_with_index);

Comment thread
mpragnay marked this conversation as resolved.
// Determine at-fault
CollisionType ct = env->agents[agent_idx].current_collision_type;
int at_fault =
(ct == STATIONARY_AGENT_COLLISION || ct == ACTIVE_FRONT_COLLISION || ct == ACTIVE_LATERAL_COLLISION);
int not_fault_ct = (ct == ACTIVE_REAR_COLLISION || ct == STATIONARY_EGO_COLLISION);
env->log.at_fault_collision_count += at_fault;
env->log.not_at_fault_collision_count += not_fault_ct;
}

if (collided == VEHICLE_COLLISION) {
if (env->collision_behavior == STOP_AGENT && !agent->stopped) {
Expand Down
Loading