 os.environ.setdefault("GOMEMLIMIT", "2GiB")
 
 import asyncio
-import gc
 import logging
 import time
 import uuid
 
 import torch
 
-# Optional memory monitoring
-try:
-    import psutil
-
-    HAS_PSUTIL = True
-except ImportError:
-    HAS_PSUTIL = False
-
 # Configure logging to see INFO level messages
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
 import torch.nn.functional as F
 import torchstore as ts
 from datasets import load_dataset
-from envs.coding_env import CodeAction, CodingEnv
 from forge.actors._torchstore_utils import (
     get_dcp_whole_state_dict_key,
     get_param_prefix,
 )
 from forge.actors.generator import Generator
-from forge.actors.generic_openenv import GenericOpenEnvActor
+
+# from forge.actors.podman_coder import PodmanPythonCoder
+from forge.actors.openenv_coder import OpenEnvCoder
+
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.trainer import RLTrainer
@@ -178,7 +171,7 @@ def simple_grpo_loss(
     ref_logprobs: torch.Tensor,
     advantages: torch.Tensor,
     padding_mask: torch.Tensor,
-    beta: float = 0.01,
+    beta: float = 0.001,
 ) -> torch.Tensor:
     """
     GRPO Loss Function for on-policy samples with numerical stability improvements
@@ -196,14 +189,41 @@ def simple_grpo_loss(
     ref_logprobs is ONLY used for the KL penalty, not the policy ratio.
     """
     logprobs: torch.Tensor = compute_logprobs(logits, response)
-    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
+
+    # Check for NaN/Inf in logprobs
+    if torch.isnan(logprobs).any() or torch.isinf(logprobs).any():
+        print("WARNING: NaN/Inf detected in logprobs!")
+        logprobs = torch.nan_to_num(logprobs, nan=0.0, posinf=0.0, neginf=-100.0)
+
+    # ✅ CORRECT: On-policy REINFORCE gradient
+    # This gives gradient: -A · ∇log p_current
+    # Forward value: 1.0 * advantages (since exp(0) = 1)
+    # Backward gradient: advantages · ∇log p_current
     per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
+
+    # ✅ KL divergence penalty with numerical stability
+    # Clamp to prevent extreme values while allowing meaningful divergence
+    delta = (ref_logprobs - logprobs).clamp(-10, 10)
+    kl = torch.exp(delta) - delta - 1
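+    # exp(delta) - delta - 1 is the non-negative "k3" estimator of KL(policy || ref):
+    # it is 0 when logprobs match ref_logprobs and grows as the policy drifts away.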
+
+    # ✅ Clamp KL to prevent extreme values
+    kl = kl.clamp(-20, 20)
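+    # (the k3 estimator is already >= 0, so only the upper bound of this clamp can bind)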
+
     per_token_loss = -(per_token_policy_loss - beta * kl)
+
+    # ✅ Loss clamping as a safety measure
+    per_token_loss = per_token_loss.clamp(-100, 100)
+
     loss = (
         ((per_token_loss * padding_mask).sum(dim=1))
         / (padding_mask.sum(dim=1).clamp(min=1.0))
     ).mean()
 
+    # Check for NaN/Inf in final loss
+    if torch.isnan(loss) or torch.isinf(loss):
+        print("WARNING: NaN/Inf detected in final loss!")
+        loss = torch.tensor(0.0, device=loss.device, requires_grad=True)
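+        # Note: this fallback tensor is a fresh leaf with no path to the model's
+        # parameters, so a NaN/Inf batch contributes zero gradient and is skipped.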
+
     # ✅ Enhanced logging for debugging
     record_metric("loss/policy_loss", per_token_policy_loss.mean().item(), Reduce.MEAN)
     record_metric("loss/kl_penalty", (beta * kl).mean().item(), Reduce.MEAN)
@@ -215,6 +235,7 @@ def simple_grpo_loss(
     record_metric("loss/per_token_loss_max", per_token_loss.max().item(), Reduce.MEAN)
     record_metric("loss/logprobs_mean", logprobs.mean().item(), Reduce.MEAN)
     record_metric("loss/ref_logprobs_mean", ref_logprobs.mean().item(), Reduce.MEAN)
+    record_metric("loss/delta_mean", delta.mean().item(), Reduce.MEAN)
 
     return loss
 
@@ -366,28 +387,52 @@ def setup(self):
 
         def get_coding_system_prompt():
             """Get system prompt for coding tasks."""
+            return """You are an expert Python programmer who writes clean, efficient, and well-tested code.
+
+Given a problem description, write a Python function that solves it following these guidelines:
+
+**CODE REQUIREMENTS:**
+1. **Write clean and efficient code**: Use clear variable names, proper structure, and Pythonic idioms
+2. **Include comprehensive docstrings**: Explain what the function does, parameters, return values, and any important notes
+3. **Handle edge cases**: Consider and appropriately handle boundary conditions and potential errors
+4. **Ensure correctness**: Your solution should be robust and handle all requirements
 
-            return """You are an expert Python programmer who writes clean, efficient, and well-tested code.
+**CRITICAL RESTRICTIONS (Your code WILL FAIL if you violate these):**
 
-Given a problem description, write a Python function that solves it following these guidelines:
+**FORBIDDEN KEYWORDS:**
+- NO `global` keyword (use function parameters/returns instead)
+- NO `yield` keyword (no generators, use lists instead)
+- NO `nonlocal` keyword (restructure your code to avoid it)
 
-**CODE REQUIREMENTS:**
-1. **Write clean and efficient code**: Use clear variable names, proper structure, and Pythonic idioms
-2. **Include comprehensive docstrings**: Explain what the function does, parameters, return values, and any important notes
-3. **Handle edge cases**: Consider and appropriately handle boundary conditions and potential errors
-4. **Ensure correctness**: Your solution should be robust and handle all requirements
+**FORBIDDEN OPERATIONS:**
+- NO dunder attributes: `__dict__`, `__name__`, `__code__`, etc.
+- NO dunder methods: `__contains__()`, etc. (use `in` operator instead)
+- NO `input()` function (all inputs come from function parameters)
+- NO `locals()` or `globals()` functions
+- NO nested class definitions
 
+**FILE OPERATIONS:**
+- Use `pathlib` for file paths, NOT `os.path`
+- Example: `from pathlib import Path; p = Path('/path/to/file')`
 
+**ALLOWED STANDARD LIBRARY IMPORTS:**
+- Core: sys, os, functools, typing, math, random, time, datetime, re, collections, itertools, statistics
+- Data: json, csv, struct, base64, dataclasses, copy, heapq, enum
+- Strings: string, ast, unicodedata
+- Advanced: abc, contextlib, inspect, secrets, uuid, pathlib, io
+- Async/Threading: threading, asyncio, concurrent.futures
+- Network: socket, urllib.parse
 
-**FORMAT YOUR RESPONSE AS:**
+**FORMAT YOUR RESPONSE AS:**
 
-```python
-def function_name(parameters):
-    \"\"\"Comprehensive docstring explaining the function.\"\"\"
-    # Implementation here
-    pass
-```
-"""
+```python
+def function_name(parameters):
+    \"\"\"Comprehensive docstring explaining the function.\"\"\"
+    # Implementation here
+    pass
+```
+
+Provide the final, working solution. Focus on correctness, readability, and efficiency."""
 
         def transform_sample(sample):
             # AceCode format with OSS filtering
@@ -505,33 +550,56 @@ async def main(cfg: DictConfig):
 
     # ---- Setup services ---- #
 
-    # Setup coding environment using GenericOpenEnvActor with CodingEnv
-    # This actor provides a sandboxed Python execution environment via OpenEnv.
-    # Get docker image and env vars from config, with sensible defaults
-    coding_env_config = cfg.get("coding_env", {})
-    docker_image = coding_env_config.get("docker_image", "coding-env:latest")
-    additional_imports = coding_env_config.get(
-        "additional_imports", ["sys", "os", "functools", "typing"]
-    )
-    env_vars = {"PYTHON_ADDITIONAL_IMPORTS": ",".join(additional_imports)}
-    container_timeout_s = coding_env_config.get("container_timeout_s", 180.0)
-    request_timeout_s = coding_env_config.get("request_timeout_s", 120.0)
-    container_memory_gb = coding_env_config.get("container_memory_gb", 4)
-
-    coder_actor = await GenericOpenEnvActor.as_actor(
-        env_class=CodingEnv,
-        action_class=CodeAction,
-        docker_image=docker_image,
-        env_vars=env_vars,
-        container_timeout_s=container_timeout_s,
-        request_timeout_s=request_timeout_s,
-        container_memory_gb=container_memory_gb,
-        enable_zombie_cleanup=True,  # Enable for code execution environments
+    # Setup coding environment with comprehensive standard library imports
+    # Based on an analysis of import failures: 143 numpy, 47 requests, 35 urllib.parse, 31 socket, 31 dataclasses
+    coder_actor = await OpenEnvCoder.as_actor(
+        additional_imports=[
+            # Core (default)
+            "sys",
+            "os",
+            "functools",
+            "typing",
+            # Data Science & Numerical
+            "numpy",
+            "pandas",
+            # Data Structures & Collections (31 dataclasses, 22 copy, 19 heapq, 17 enum)
+            "dataclasses",
+            "copy",
+            "heapq",
+            "enum",
+            # String & Text Processing (22 string, 21 ast)
+            "string",
+            "ast",
+            # Data Formats & Serialization (25 json, 15 struct, 10 base64, 5 csv)
+            "json",
+            "struct",
+            "base64",
+            "csv",
+            # Math & Numbers (12 cmath)
+            "cmath",
+            # Abstract Base Classes & Patterns (16 abc, 7 contextlib, 7 inspect)
+            "abc",
+            "contextlib",
+            "inspect",
+            # Security & Utilities (16 secrets, 4 uuid)
+            "secrets",
+            "uuid",
+            # I/O & Path Operations (6 pathlib, 5 io)
+            "pathlib",
+            "io",
+            # Async & Concurrency (11 threading, 6 asyncio, 3 concurrent.futures)
+            "threading",
+            "asyncio",
+            "concurrent.futures",
+            # Network & Web (35 urllib.parse, 31 socket)
+            "urllib.parse",
+            "socket",
+        ]
     )
 
     # Setup coding reward functions
     ground_truth_reward = GroundTruthTestReward(coder_actor)
-    thinking_reward = ThinkingReward()
+    # thinking_reward = ThinkingReward()
 
     (
         dataloader,
@@ -553,7 +621,7 @@ async def main(cfg: DictConfig):
         ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
         ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
         RewardActor.options(**cfg.services.reward_actor).as_service(
-            reward_functions=[ground_truth_reward, thinking_reward]
+            reward_functions=[ground_truth_reward]
         ),
     )
 
@@ -642,36 +710,6 @@ async def continuous_rollouts():
642710 "main/continuous_rollouts/count_rollout_iterations" , 1 , Reduce .SUM
643711 )
644712 t .stop ()
645-
646- # CRITICAL: Explicit memory cleanup to prevent leaks
647- # Clear tensor references
648- del episodes , advantages , responses
649- # Clear CUDA cache if using GPU
650- if torch .cuda .is_available ():
651- torch .cuda .empty_cache ()
652- # Force garbage collection every rollout to prevent accumulation
653- gc .collect ()
654-
655- # CRITICAL: Clean up zombie processes periodically (every 5 rollouts)
656- # This prevents accumulation of timed-out processes consuming memory
657- if rollout_count % 5 == 0 :
658- killed_count = await coder_actor .cleanup_zombie_processes .call_one ()
659- if killed_count > 0 :
660- print (
661- f"Rollout { rollout_count } : Cleaned up { killed_count } zombie processes"
662- )
663- record_metric (
664- "memory/zombie_processes_killed" , killed_count , Reduce .SUM
665- )
666-
667- # Log memory usage periodically (every 10 rollouts) if psutil is available
668- if rollout_count % 10 == 0 and HAS_PSUTIL :
669- process = psutil .Process ()
670- memory_mb = process .memory_info ().rss / 1024 / 1024
671- record_metric ("memory/process_memory_mb" , memory_mb , Reduce .MEAN )
672- print (
673- f"Rollout { rollout_count } : Process memory = { memory_mb :.2f} MB"
674- )
675713 except RuntimeError as e :
676714 error_msg = str (e ).lower ()
677715 # Check if this is a container-related error that couldn't be recovered
@@ -759,7 +797,7 @@ async def continuous_training():
     except KeyboardInterrupt:
         print("Training interrupted by user")
     finally:
-        print("Shutting down... (this may take a few seconds)")
+        print("Shutting down...")
         shutdown_event.set()
 
         try:
@@ -784,6 +822,8 @@ async def continuous_training():
 @parse
 def _main(cfg):
     """Main entry point for GRPO training."""
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+    os.environ["NCCL_TIMEOUT_MS"] = "60000"  # 60 second timeout
     os.environ["MONARCH_HOSTMESH_V1"] = "1"
     os.environ["TORCHSTORE_RDMA_ENABLED"] = "1"
     # os.environ["FORGE_DISABLE_METRICS"] = "1"