From 15cb9db6b00296a64487c8bdbf1a54d70fa97f2a Mon Sep 17 00:00:00 2001 From: ry2009 <134240944+ry2009@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:41:06 -0500 Subject: [PATCH 1/2] Add vecenv fallback and fix batched forward state --- pufferlib/extensions/pufferlib.cpp | 257 ++++++++++++++++++++--------- pufferlib/extensions/vecenv.h | 1 - 2 files changed, 182 insertions(+), 76 deletions(-) diff --git a/pufferlib/extensions/pufferlib.cpp b/pufferlib/extensions/pufferlib.cpp index 64e09a59a..9cd2261f9 100644 --- a/pufferlib/extensions/pufferlib.cpp +++ b/pufferlib/extensions/pufferlib.cpp @@ -94,73 +94,96 @@ void clip_grad_norm_( std::tuple create_environments(int64_t num_envs, int threads) { + // Try to load native vecenv; fall back to dummy env if symbols are missing. void* handle = dlopen("./breakout.so", RTLD_NOW); - if (!handle) { - fprintf(stderr, "dlopen error: %s\n", dlerror()); - exit(1); - } - dlerror(); - - // Load the function pointer - create_envs = (create_environments_fn)dlsym(handle, "create_environments"); - env_init = (env_init_fn)dlsym(handle, "env_init"); - vec_reset = (vec_reset_fn)dlsym(handle, "vec_reset"); - vec_step = (vec_step_fn)dlsym(handle, "vec_step"); - env_close = (env_close_fn)dlsym(handle, "env_close"); - vec_close = (vec_close_fn)dlsym(handle, "vec_close"); - vec_log = (vec_log_fn)dlsym(handle, "vec_log"); - vec_render = (vec_render_fn)dlsym(handle, "vec_render"); - int obs_n = *(int*)dlsym(handle, "OBS_N"); - int act_n = *(int*)dlsym(handle, "ACT_N"); - int obs_t = *(int*)dlsym(handle, "OBS_T"); - int act_t = *(int*)dlsym(handle, "ACT_T"); - - const char* dlsym_error = dlerror(); - if (dlsym_error) { - fprintf(stderr, "dlsym error: %s\n", dlsym_error); - dlclose(handle); - exit(1); + bool loaded = handle != nullptr; + + if (loaded) { + dlerror(); + create_envs = (create_environments_fn)dlsym(handle, "create_environments"); + env_init = (env_init_fn)dlsym(handle, "env_init"); + vec_reset = (vec_reset_fn)dlsym(handle, "vec_reset"); + vec_step = (vec_step_fn)dlsym(handle, "vec_step"); + env_close = (env_close_fn)dlsym(handle, "env_close"); + vec_close = (vec_close_fn)dlsym(handle, "vec_close"); + vec_log = (vec_log_fn)dlsym(handle, "vec_log"); + vec_render = (vec_render_fn)dlsym(handle, "vec_render"); + + int* obs_n_ptr = (int*)dlsym(handle, "OBS_N"); + int* act_n_ptr = (int*)dlsym(handle, "ACT_N"); + int* obs_t_ptr = (int*)dlsym(handle, "OBS_T"); + int* act_t_ptr = (int*)dlsym(handle, "ACT_T"); + + const char* err = dlerror(); + if (err || !create_envs || !vec_step || !vec_reset || + !obs_n_ptr || !act_n_ptr || !obs_t_ptr || !act_t_ptr) { + fprintf(stderr, "[pufferlib] dlopen fallback: %s\n", err ? err : "missing symbol"); + loaded = false; + dlclose(handle); + } else { + int obs_n = *obs_n_ptr; + int act_n = *act_n_ptr; + int obs_t = *obs_t_ptr; + int act_t = *act_t_ptr; + + Dict* kwargs = create_dict(32); + dict_set_int(kwargs, "frameskip", 4); + dict_set_int(kwargs, "width", 576); + dict_set_int(kwargs, "height", 330); + dict_set_int(kwargs, "paddle_width", 62); + dict_set_int(kwargs, "paddle_height", 8); + dict_set_int(kwargs, "ball_width", 32); + dict_set_int(kwargs, "ball_height", 32); + dict_set_int(kwargs, "brick_width", 32); + dict_set_int(kwargs, "brick_height", 12); + dict_set_int(kwargs, "brick_rows", 6); + dict_set_int(kwargs, "brick_cols", 18); + dict_set_int(kwargs, "initial_ball_speed", 256); + dict_set_int(kwargs, "max_ball_speed", 448); + dict_set_int(kwargs, "paddle_speed", 620); + dict_set_int(kwargs, "continuous", 0); + + VecEnv* vec = create_envs(num_envs, threads, kwargs); + printf("Created VecEnv with %d environments\n", vec->size); + + auto obs_dtype = to_torch_dtype(obs_t); + auto atn_dtype = to_torch_dtype(act_t); + + auto obs = torch::from_blob(vec->observations, {num_envs, obs_n}, obs_dtype).pin_memory(); + auto actions = torch::from_blob(vec->actions, {num_envs}, atn_dtype).pin_memory(); + auto rewards = torch::from_blob(vec->rewards, {num_envs}, torch::kFloat32).pin_memory(); + auto terminals = torch::from_blob(vec->terminals, {num_envs}, torch::kUInt8).pin_memory(); + + vec_reset(vec); + return std::make_tuple(vec, obs, actions, rewards, terminals); + } } - Dict* kwargs = create_dict(32); - dict_set_int(kwargs, "frameskip", 4); - dict_set_int(kwargs, "width", 576); - dict_set_int(kwargs, "height", 330); - dict_set_int(kwargs, "paddle_width", 62); - dict_set_int(kwargs, "paddle_height", 8); - dict_set_int(kwargs, "ball_width", 32); - dict_set_int(kwargs, "ball_height", 32); - dict_set_int(kwargs, "brick_width", 32); - dict_set_int(kwargs, "brick_height", 12); - dict_set_int(kwargs, "brick_rows", 6); - dict_set_int(kwargs, "brick_cols", 18); - dict_set_int(kwargs, "initial_ball_speed", 256); - dict_set_int(kwargs, "max_ball_speed", 448); - dict_set_int(kwargs, "paddle_speed", 620); - dict_set_int(kwargs, "continuous", 0); - - /* - Dict* kwargs = create_dict(32); - dict_set_int(kwargs, "can_go_over_65536", 0); - dict_set_float(kwargs, "reward_scaler", 0.67); - dict_set_float(kwargs, "endgame_env_prob", 0.05); - dict_set_float(kwargs, "scaffolding_ratio", 0.67); - dict_set_int(kwargs, "use_heuristic_rewards", 1); - dict_set_float(kwargs, "snake_reward_weight", 0.0005); - dict_set_int(kwargs, "use_sparse_reward", 0); - */ - - VecEnv* vec = create_envs(num_envs, threads, kwargs); - printf("Created VecEnv with %d environments\n", vec->size); - - // Close the library - //dlclose(handle); - - auto obs_dtype = to_torch_dtype(obs_t); - auto atn_dtype = to_torch_dtype(act_t); - - auto obs = torch::from_blob(vec->observations, {num_envs, obs_n}, obs_dtype).pin_memory(); - auto actions = torch::from_blob(vec->actions, {num_envs}, atn_dtype).pin_memory(); + // Fallback: minimal CPU vecenv with dummy step/reset. + fprintf(stderr, "[pufferlib] Using dummy vecenv fallback (no breakout.so symbols)\n"); + + auto* vec = (VecEnv*)calloc(1, sizeof(VecEnv)); + vec->size = static_cast(num_envs); + int obs_n = 118; // matches policy input size + int act_n = 1; + + vec->observations = (float*)calloc(num_envs * obs_n, sizeof(float)); + vec->actions = (float*)calloc(num_envs * act_n, sizeof(float)); + vec->rewards = (float*)calloc(num_envs, sizeof(float)); + vec->terminals = (unsigned char*)calloc(num_envs, sizeof(unsigned char)); + + vec_reset = [](VecEnv* v) { + for (int i = 0; i < v->size * 118; i++) v->observations[i] = 0.001f * (float)(rand() % 23); + memset(v->rewards, 0, sizeof(float) * v->size); + memset(v->terminals, 0, sizeof(unsigned char) * v->size); + }; + vec_step = [](VecEnv* v) { + memset(v->rewards, 0, sizeof(float) * v->size); + memset(v->terminals, 0, sizeof(unsigned char) * v->size); + }; + + auto obs = torch::from_blob(vec->observations, {num_envs, obs_n}, torch::kFloat32).pin_memory(); + auto actions = torch::from_blob(vec->actions, {num_envs}, torch::kFloat32).pin_memory(); auto rewards = torch::from_blob(vec->rewards, {num_envs}, torch::kFloat32).pin_memory(); auto terminals = torch::from_blob(vec->terminals, {num_envs}, torch::kUInt8).pin_memory(); @@ -215,7 +238,8 @@ Log log_environments(torch::Tensor envs_tensor, torch::Tensor indices_tensor) { namespace py = pybind11; -// Forward declare modules +#ifndef PUFFERLIB_NO_CUDA +// Forward declare CUDA implementations (defined in modules.cu) torch::Tensor mingru_gate( torch::Tensor state, torch::Tensor gate, @@ -225,6 +249,11 @@ torch::autograd::tensor_list log_coeffs_and_values( torch::Tensor gate, torch::Tensor hidden ); +torch::autograd::tensor_list rmsnorm( + torch::Tensor x, + torch::Tensor weight, + double eps +); torch::autograd::tensor_list fused_scan( torch::Tensor log_coeffs, torch::Tensor log_values @@ -245,15 +274,93 @@ torch::autograd::tensor_list fused_ppo_loss( float vf_clip_coef, float vf_coef, float ent_coef - /* - torch::Tensor adv_mean, - torch::Tensor adv_std, - torch::Tensor clip_coef, - torch::Tensor vf_clip_coef, - torch::Tensor vf_coef, - torch::Tensor ent_coef - */ ); +#else +// CPU fallbacks so the extension builds without CUDA. +torch::Tensor mingru_gate( + torch::Tensor state, + torch::Tensor gate, + torch::Tensor hidden +) { + auto hidden_pos = torch::where(hidden >= 0, hidden + 0.5, torch::sigmoid(hidden)); + auto gate_sig = torch::sigmoid(gate); + return torch::lerp(state, hidden_pos, gate_sig); +} + +torch::autograd::tensor_list log_coeffs_and_values( + torch::Tensor gate, + torch::Tensor hidden +) { + auto log_coeffs = -torch::nn::functional::softplus(gate); + auto log_z = -torch::nn::functional::softplus(-gate); + auto relu_h = torch::relu(hidden); + auto log_tilde_h = torch::where(hidden >= 0, (relu_h + 0.5).log(), -torch::nn::functional::softplus(-hidden)); + auto log_values = log_z + log_tilde_h; + return {log_coeffs, log_values}; +} + +torch::autograd::tensor_list fused_scan( + torch::Tensor log_coeffs, + torch::Tensor log_values +) { + auto a_star = log_coeffs.cumsum(1); + auto log_h0_plus_b_star = (log_values - a_star).logcumsumexp(1); + auto log_h = a_star + log_h0_plus_b_star; + auto out = log_h.exp(); + return {out}; +} + +torch::Tensor logcumsumexp_cuda(torch::Tensor x) { + return x.logcumsumexp(1); +} + +torch::autograd::tensor_list rmsnorm( + torch::Tensor x, + torch::Tensor weight, + double eps +) { + auto out = torch::nn::functional::rms_norm(x, torch::nn::functional::RMSNormFuncOptions(x.size(-1)).weight(weight).eps(eps)); + return {out}; +} + +torch::autograd::tensor_list fused_ppo_loss( + torch::Tensor logits, + torch::Tensor values_pred, + torch::Tensor actions, + torch::Tensor old_logprobs, + torch::Tensor advantages, + torch::Tensor prio, + torch::Tensor values, + torch::Tensor returns, + float adv_mean, + float adv_std, + float clip_coef, + float vf_clip_coef, + float vf_coef, + float ent_coef +) { + // Pure-torch CPU fallback; autograd will handle backward. + auto logp = torch::log_softmax(logits, -1); + auto new_logp = logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1); + auto ratio = (new_logp - old_logprobs).exp(); + auto adv_norm = (advantages - adv_mean) / (adv_std + 1e-8); + + auto pg_loss1 = -prio * adv_norm * ratio; + auto pg_loss2 = -prio * adv_norm * ratio.clamp(1.0 - clip_coef, 1.0 + clip_coef); + auto pg_loss = torch::max(pg_loss1, pg_loss2); + + auto v_error = values_pred - values; + auto v_clipped = values + v_error.clamp(-vf_clip_coef, vf_clip_coef); + auto v_loss_unclipped = (values_pred - returns).pow(2); + auto v_loss_clipped = (v_clipped - returns).pow(2); + auto v_loss = 0.5 * torch::max(v_loss_unclipped, v_loss_clipped); + + auto entropy = -(logp * logp.exp()).sum(-1); + + auto loss = (pg_loss + vf_coef * v_loss - ent_coef * entropy).mean(); + return {loss.unsqueeze(0)}; +} +#endif /* torch::autograd::tensor_list rmsnorm( @@ -1126,7 +1233,7 @@ void batched_forward( float rng = static_cast(rand()) / static_cast(RAND_MAX); torch::Tensor mb_obs = observations.narrow(0, mb*minibatch_segments, minibatch_segments); torch::Tensor mb_state = torch::zeros( - {minibatch_segments, 1, policy->hidden_size}, + {policy->num_layers, minibatch_segments, 1, policy->hidden_size}, DTYPE ).to(device); auto [logits, newvalue] = policy->forward_train(mb_obs.to(DTYPE)+rng, mb_state+rng); @@ -1380,7 +1487,7 @@ PYBIND11_MODULE(_C, m) { m.def("log_coeffs_and_values", &log_coeffs_and_values); m.def("fused_scan", &fused_scan); m.def("fused_ppo_loss", &fused_ppo_loss); - //m.def("rmsnorm", &rmsnorm); + m.def("rmsnorm", &rmsnorm); /* py::class_>(m, "RMSNorm") diff --git a/pufferlib/extensions/vecenv.h b/pufferlib/extensions/vecenv.h index ee85296b3..df1022b08 100644 --- a/pufferlib/extensions/vecenv.h +++ b/pufferlib/extensions/vecenv.h @@ -2,7 +2,6 @@ #include #include #include -#include #define FLOAT 1 #define INT 2 From 77f1987b4e903ed0f35ad93d390f88d33f4e291a Mon Sep 17 00:00:00 2001 From: ry2009 <134240944+ry2009@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:52:13 -0500 Subject: [PATCH 2/2] Enable fused RMSNorm, wire fallback tests, and guard CPU --- pufferlib/extensions/cuda/kernels.cu | 268 +++++++++++++-------------- pufferlib/extensions/cuda/modules.cu | 22 +-- pufferlib/models.py | 24 ++- setup.py | 3 + test_kernels.py | 5 + tests/test_cpu_fallbacks.py | 71 +++++++ 6 files changed, 239 insertions(+), 154 deletions(-) create mode 100644 tests/test_cpu_fallbacks.py diff --git a/pufferlib/extensions/cuda/kernels.cu b/pufferlib/extensions/cuda/kernels.cu index 520812ced..885630913 100644 --- a/pufferlib/extensions/cuda/kernels.cu +++ b/pufferlib/extensions/cuda/kernels.cu @@ -17,142 +17,122 @@ inline int seq_size(int N) { return (N + SEQ_SIZE - 1) / SEQ_SIZE; } -// If you can get this to work, go ahead. I tried. -// NVCC won't parse templated types in kernel launches -/* -template