Skip to content
Open
3 changes: 3 additions & 0 deletions pufferlib/config/ocean/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ multi_scenario_render_backend = egl
; Frequency of evaluation during training (in epochs)
eval_interval = 25
num_agents = 512
; Number of agents per environment in gigaflow eval mode
min_agents_per_env = 1
max_agents_per_env = 1
; Batch size for eval_multi_scenarios (number of scenarios per batch)
; Path to dataset used for evaluation
map_dir = "pufferlib/resources/drive/binaries/carla_py123d"
Expand Down
134 changes: 121 additions & 13 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
@@ -1,10 +1,67 @@
#include <Python.h>
#include "drive.h"
#define Env Drive
#define MY_SHARED
#define MY_PUT
#define MY_GET
#include "../env_binding.h"

// Process-local map cache: indexed by map_id, populated by my_shared(), used by my_init().
// Slots start out NULL (the array is calloc'd) and are filled lazily by my_shared() for the maps it assigns.
static SharedMapData **g_map_cache = NULL;
// Number of slots in g_map_cache and in g_cache_map_paths; set to num_maps when the cache is (re)built.
static int g_map_cache_size = 0;
// PID of the process that built the cache (getpid() at build time); 0 while no cache exists.
static pid_t g_map_cache_pid = 0;
// Cache key: params that affect SharedMapData content (obs dists determine vision_range)
static float g_cache_road_obs_front_dist = 0;
static float g_cache_road_obs_behind_dist = 0;
static float g_cache_road_obs_side_dist = 0;
// Cache key: strdup'd copy of every map file path, compared with strcmp() before the cache is reused.
static char **g_cache_map_paths = NULL;

// Put every cache global back into its "no cache" state.
// Only the variables are cleared; nothing is freed here, so a caller that owns
// the allocations has to release them before calling this.
static void reset_cache_globals(void) {
    // Cache-key parameters.
    g_cache_road_obs_front_dist = g_cache_road_obs_behind_dist = g_cache_road_obs_side_dist = 0;
    g_cache_map_paths = NULL;

    // The cache table itself and its bookkeeping.
    g_map_cache_pid = 0;
    g_map_cache_size = 0;
    g_map_cache = NULL;
}

static void release_map_cache_internal(void) {
if (g_map_cache == NULL)
return;
// After fork, the child holds its own copy-on-write copy of the cache that the parent built.
// We deliberately do NOT free it here: the child just drops its references (those allocations
// are never released in the child) and rebuilds its own cache on the next call.
if (g_map_cache_pid != 0 && g_map_cache_pid != getpid()) {
reset_cache_globals();
return;
}
Comment on lines +33 to +38
Copy link

Copilot AI Apr 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fork-handling branch in release_map_cache_internal() resets the cache globals on PID mismatch without freeing g_map_cache / g_cache_map_paths or the SharedMapData entries. In a forked child this drops the only pointers to large allocations, effectively leaking them for the lifetime of the child process (and preventing the OS from reclaiming COW pages if the parent later releases them). Consider either (a) reusing the inherited cache after fork by updating g_map_cache_pid to getpid(), or (b) fully freeing/detaching the inherited cache in the child before resetting the globals.

Suggested change
// We must NOT free them — they belong to the parent's address space.
// Discard them and let the child rebuild its own cache on the next call.
if (g_map_cache_pid != 0 && g_map_cache_pid != getpid()) {
reset_cache_globals();
return;
}
// Adopt the inherited cache in the child so it can be released normally
// instead of dropping the only references to these allocations.
if (g_map_cache_pid != 0 && g_map_cache_pid != getpid())
g_map_cache_pid = getpid();

Copilot uses AI. Check for mistakes.
// Entries with live refs are detached: c_close will free them when ref_count reaches 0.
// Entries with no live refs are freed immediately.
for (int i = 0; i < g_map_cache_size; i++) {
if (g_map_cache[i] == NULL)
continue;
if (g_map_cache[i]->ref_count > 0)
g_map_cache[i]->detached = 1;
else
free_shared_map_data(g_map_cache[i]);
}
free(g_map_cache);
if (g_cache_map_paths != NULL) {
for (int i = 0; i < g_map_cache_size; i++)
free(g_cache_map_paths[i]);
free(g_cache_map_paths);
}
reset_cache_globals();
}

// Python-callable wrapper: releases the process-local shared map cache and
// returns None. Neither `self` nor `args` is inspected.
static PyObject *release_map_cache_py(PyObject *self __attribute__((unused)),
                                      PyObject *args __attribute__((unused))) {
    release_map_cache_internal();

    Py_RETURN_NONE;
}

// PyMethodDef-style entry that binds the Python name `release_map_cache` to release_map_cache_py.
// NOTE(review): presumably expanded inside the method table of ../env_binding.h; if so, it must be
// defined before that header is included. Confirm the ordering in the full file.
#define MY_METHODS {"release_map_cache", release_map_cache_py, METH_VARARGS, "Release the shared map data cache"}

static int my_put(Env *env, PyObject *args, PyObject *kwargs) {
PyObject *obs = PyDict_GetItemString(kwargs, "observations");
if (!PyObject_TypeCheck(obs, &PyArray_Type)) {
Expand Down Expand Up @@ -1544,6 +1601,9 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {
int max_agents_per_env = unpack(kwargs, "max_agents_per_env");
float goal_radius = (float)unpack(kwargs, "goal_radius");
int num_eval_scenarios = unpack(kwargs, "num_eval_scenarios");
float road_obs_front_dist = (float)unpack(kwargs, "road_obs_front_dist");
float road_obs_behind_dist = (float)unpack(kwargs, "road_obs_behind_dist");
float road_obs_side_dist = (float)unpack(kwargs, "road_obs_side_dist");
if (min_agents_per_env <= 0 || max_agents_per_env <= 0) {
PyErr_SetString(PyExc_ValueError, "min_agents_per_env and max_agents_per_env must be > 0");
return NULL;
Expand All @@ -1559,7 +1619,38 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {

srand(seed);

// GIGAFLOW mode: use random sampling for agent counts per env
// Reuse the existing cache if this process created it with matching config.
// Cache key: PID, num_maps, map file paths, obs dist params (determine vision_range).
int reuse_cache =
(g_map_cache != NULL && g_map_cache_pid == getpid() && g_map_cache_size == num_maps &&
g_cache_road_obs_front_dist == road_obs_front_dist && g_cache_road_obs_behind_dist == road_obs_behind_dist &&
g_cache_road_obs_side_dist == road_obs_side_dist);
if (reuse_cache && g_cache_map_paths != NULL) {
for (int i = 0; i < num_maps; i++) {
const char *path = PyUnicode_AsUTF8(PyList_GetItem(map_files, i));
if (g_cache_map_paths[i] == NULL || strcmp(g_cache_map_paths[i], path) != 0) {
reuse_cache = 0;
break;
}
}
}
if (!reuse_cache) {
release_map_cache_internal();
g_map_cache_size = num_maps;
g_map_cache = (SharedMapData **)calloc(num_maps, sizeof(SharedMapData *));
g_map_cache_pid = getpid();
g_cache_road_obs_front_dist = road_obs_front_dist;
g_cache_road_obs_behind_dist = road_obs_behind_dist;
g_cache_road_obs_side_dist = road_obs_side_dist;
g_cache_map_paths = (char **)calloc(num_maps, sizeof(char *));
for (int i = 0; i < num_maps; i++) {
const char *path = PyUnicode_AsUTF8(PyList_GetItem(map_files, i));
g_cache_map_paths[i] = strdup(path);
}
}

// GIGAFLOW mode: agent counts are numeric, no binary loading needed for counting.
// We do lazily populate the cache so my_init can call init_from_shared.
if (simulation_mode == SIMULATION_GIGAFLOW) {
if (eval_mode) {
// Eval mode: fixed agent count, sequential map cycling
Expand All @@ -1573,10 +1664,17 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {

int offset = 0;
for (int i = 0; i < env_count; i++) {
int map_id_g = (s_map_counter + i) % num_maps;
PyList_SetItem(agent_offsets, i, PyLong_FromLong(offset));
PyList_SetItem(map_ids_list, i, PyLong_FromLong((s_map_counter + i) % num_maps));
PyList_SetItem(map_ids_list, i, PyLong_FromLong(map_id_g));
int remaining = num_agents - offset;
offset += (remaining < agents_per_env) ? remaining : agents_per_env;
// Lazily populate cache for assigned map
if (g_map_cache[map_id_g] == NULL) {
const char *map_file_path = PyUnicode_AsUTF8(PyList_GetItem(map_files, map_id_g));
g_map_cache[map_id_g] = create_shared_map_data(map_file_path, road_obs_front_dist,
road_obs_behind_dist, road_obs_side_dist);
}
}
PyList_SetItem(agent_offsets, env_count, PyLong_FromLong(offset));

Expand All @@ -1597,23 +1695,13 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {
if (remaining <= max_agents_per_env) {
count = remaining;
} else {
// 1. We must leave at least min_agents_per_env for the future.
int absolute_max_allowed = remaining - min_agents_per_env;

// 2. We cannot take more than max_agents_per_env right now.
int current_upper_bound =
(absolute_max_allowed < max_agents_per_env) ? absolute_max_allowed : max_agents_per_env;

// 3. We must take at least min_agents_per_env right now.
int current_lower_bound = min_agents_per_env;

// Safety check: when the constraints are tight, the upper bound can end up at or
// below the lower bound (e.g. absolute_max_allowed < min_agents_per_env).
// In that case fall back to the lower bound instead of sampling.
if (current_upper_bound <= current_lower_bound) {
count = current_lower_bound;
} else {
// Now the range is guaranteed to be positive.
int range = current_upper_bound - current_lower_bound + 1;
count = current_lower_bound + (rand() % range);
}
Expand All @@ -1628,9 +1716,16 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {

int offset = 0;
for (int i = 0; i < env_count; i++) {
int map_id_g = rand() % num_maps;
PyList_SetItem(agent_offsets, i, PyLong_FromLong(offset));
PyList_SetItem(map_ids_list, i, PyLong_FromLong(rand() % num_maps));
PyList_SetItem(map_ids_list, i, PyLong_FromLong(map_id_g));
offset += agent_counts[i];
// Lazily populate cache for assigned map
if (g_map_cache[map_id_g] == NULL) {
const char *map_file_path = PyUnicode_AsUTF8(PyList_GetItem(map_files, map_id_g));
g_map_cache[map_id_g] = create_shared_map_data(map_file_path, road_obs_front_dist, road_obs_behind_dist,
road_obs_side_dist);
}
}
PyList_SetItem(agent_offsets, env_count, PyLong_FromLong(num_agents));

Expand Down Expand Up @@ -1810,6 +1905,19 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
env->phantom_braking_trigger_prob = (float)unpack(kwargs, "phantom_braking_trigger_prob");
env->phantom_braking_duration = (int)unpack(kwargs, "phantom_braking_duration");

// Use shared map cache if map_id is provided and cache entry exists
PyObject *map_id_obj = kwargs ? PyDict_GetItemString(kwargs, "map_id") : NULL;
if (map_id_obj != NULL && g_map_cache != NULL) {
int map_id = (int)PyLong_AsLong(map_id_obj);
if (map_id >= 0 && map_id < g_map_cache_size && g_map_cache[map_id] != NULL) {
init_from_shared(env, g_map_cache[map_id]);
return 0;
}
// Cache miss: warn and fall through to disk loading
fprintf(stderr, "WARNING: map_id=%d provided but shared map cache miss — loading from disk\n", map_id);
}

// Fallback: load map from disk (standalone use, tests, or cache miss)
init(env);
return 0;
}
Expand Down
Loading
Loading