
Commit 536e0c2

third checkpoint
1 parent 51f6138 commit 536e0c2

16 files changed (+823 -200 lines)

apps/coding-grpo/main.py

Lines changed: 123 additions & 83 deletions
@@ -20,7 +20,6 @@
 os.environ.setdefault("GOMEMLIMIT", "2GiB")

 import asyncio
-import gc
 import logging
 import time
 import uuid
@@ -29,28 +28,22 @@

 import torch

-# Optional memory monitoring
-try:
-    import psutil
-
-    HAS_PSUTIL = True
-except ImportError:
-    HAS_PSUTIL = False
-
 # Configure logging to see INFO level messages
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
 import torch.nn.functional as F
 import torchstore as ts
 from datasets import load_dataset
-from envs.coding_env import CodeAction, CodingEnv
 from forge.actors._torchstore_utils import (
     get_dcp_whole_state_dict_key,
     get_param_prefix,
 )
 from forge.actors.generator import Generator
-from forge.actors.generic_openenv import GenericOpenEnvActor
+
+# from forge.actors.podman_coder import PodmanPythonCoder
+from forge.actors.openenv_coder import OpenEnvCoder
+
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.trainer import RLTrainer
@@ -178,7 +171,7 @@ def simple_grpo_loss(
     ref_logprobs: torch.Tensor,
     advantages: torch.Tensor,
     padding_mask: torch.Tensor,
-    beta: float = 0.01,
+    beta: float = 0.001,
 ) -> torch.Tensor:
     """
     GRPO Loss Function for on-policy samples with numerical stability improvements
@@ -196,14 +189,41 @@
     ref_logprobs is ONLY used for the KL penalty, not the policy ratio.
     """
     logprobs: torch.Tensor = compute_logprobs(logits, response)
-    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
+
+    # Check for NaN/Inf in logprobs
+    if torch.isnan(logprobs).any() or torch.isinf(logprobs).any():
+        print("WARNING: NaN/Inf detected in logprobs!")
+        logprobs = torch.nan_to_num(logprobs, nan=0.0, posinf=0.0, neginf=-100.0)
+
+    # ✅ CORRECT: On-policy REINFORCE gradient
+    # This gives gradient: -A · ∇log p_current
+    # Forward value: 1.0 * advantages (since exp(0) = 1)
+    # Backward gradient: advantages · ∇log p_current
     per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
+
+    # ✅ KL divergence penalty with numerical stability
+    # Clamp to prevent extreme values while allowing meaningful divergence
+    delta = (ref_logprobs - logprobs).clamp(-10, 10)
+    kl = torch.exp(delta) - delta - 1
+
+    # ✅ Clamp KL to prevent extreme values
+    kl = kl.clamp(-20, 20)
+
     per_token_loss = -(per_token_policy_loss - beta * kl)
+
+    # ✅ Loss clamping as a safety measure
+    per_token_loss = per_token_loss.clamp(-100, 100)
+
     loss = (
         ((per_token_loss * padding_mask).sum(dim=1))
         / (padding_mask.sum(dim=1).clamp(min=1.0))
     ).mean()

+    # Check for NaN/Inf in final loss
+    if torch.isnan(loss) or torch.isinf(loss):
+        print("WARNING: NaN/Inf detected in final loss!")
+        loss = torch.tensor(0.0, device=loss.device, requires_grad=True)
+
     # ✅ Enhanced logging for debugging
     record_metric("loss/policy_loss", per_token_policy_loss.mean().item(), Reduce.MEAN)
     record_metric("loss/kl_penalty", (beta * kl).mean().item(), Reduce.MEAN)
@@ -215,6 +235,7 @@ def simple_grpo_loss(
     record_metric("loss/per_token_loss_max", per_token_loss.max().item(), Reduce.MEAN)
     record_metric("loss/logprobs_mean", logprobs.mean().item(), Reduce.MEAN)
     record_metric("loss/ref_logprobs_mean", ref_logprobs.mean().item(), Reduce.MEAN)
+    record_metric("loss/delta_mean", delta.mean().item(), Reduce.MEAN)

     return loss

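Note on the loss change above: the policy term keeps its forward value equal to the advantages (since exp(0) = 1) while routing gradients through the current log-probabilities, and the KL penalty is the clamped k3-style estimator exp(δ) − δ − 1. A minimal sketch on toy tensors (illustration only, not part of the commit) showing both behaviors:

import torch

logprobs = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
ref_logprobs = torch.tensor([-1.0, -0.9, -2.0])
advantages = torch.tensor([0.5, -0.3, 1.0])

# Forward value equals `advantages` (exp(0) == 1); the gradient w.r.t. logprobs
# is `advantages`, i.e. the REINFORCE gradient A * d(log p).
policy_term = torch.exp(logprobs - logprobs.detach()) * advantages
policy_term.sum().backward()
print(policy_term.detach())  # tensor([ 0.5000, -0.3000,  1.0000])
print(logprobs.grad)         # tensor([ 0.5000, -0.3000,  1.0000])

# Clamped k3-style KL estimate: non-negative, ~0 when logprobs ≈ ref_logprobs,
# and a single extreme token cannot blow up the loss. (.detach() here only keeps
# the demo free of autograd bookkeeping.)
delta = (ref_logprobs - logprobs.detach()).clamp(-10, 10)
kl = (torch.exp(delta) - delta - 1).clamp(-20, 20)
print(kl)  # ~tensor([0.0214, 0.0187, 0.0499])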
@@ -366,28 +387,52 @@ def setup(self):

 def get_coding_system_prompt():
     """Get system prompt for coding tasks."""
+    return """You are an expert Python programmer who writes clean, efficient, and well-tested code.
+
+Given a problem description, write a Python function that solves it following these guidelines:
+
+**CODE REQUIREMENTS:**
+1. **Write clean and efficient code**: Use clear variable names, proper structure, and Pythonic idioms
+2. **Include comprehensive docstrings**: Explain what the function does, parameters, return values, and any important notes
+3. **Handle edge cases**: Consider and appropriately handle boundary conditions and potential errors
+4. **Ensure correctness**: Your solution should be robust and handle all requirements

-    return """You are an expert Python programmer who writes clean, efficient, and well-tested code.
+**CRITICAL RESTRICTIONS (Your code WILL FAIL if you violate these):**

-Given a problem description, write a Python function that solves it following these guidelines:
+**FORBIDDEN KEYWORDS:**
+- NO `global` keyword (use function parameters/returns instead)
+- NO `yield` keyword (no generators, use lists instead)
+- NO `nonlocal` keyword (restructure your code to avoid it)

-**CODE REQUIREMENTS:**
-1. **Write clean and efficient code**: Use clear variable names, proper structure, and Pythonic idioms
-2. **Include comprehensive docstrings**: Explain what the function does, parameters, return values, and any important notes
-3. **Handle edge cases**: Consider and appropriately handle boundary conditions and potential errors
-4. **Ensure correctness**: Your solution should be robust and handle all requirements
+**FORBIDDEN OPERATIONS:**
+- NO dunder attributes: `__dict__`, `__name__`, `__code__`, etc.
+- NO dunder methods: `__contains__()`, etc. (use `in` operator instead)
+- NO `input()` function (all inputs come from function parameters)
+- NO `locals()` or `globals()` functions
+- NO nested class definitions

+**FILE OPERATIONS:**
+- Use `pathlib` for file paths, NOT `os.path`
+- Example: `from pathlib import Path; p = Path('/path/to/file')`

+**ALLOWED STANDARD LIBRARY IMPORTS:**
+- Core: sys, os, functools, typing, math, random, time, datetime, re, collections, itertools, statistics
+- Data: json, csv, struct, base64, dataclasses, copy, heapq, enum
+- Strings: string, ast, unicodedata
+- Advanced: abc, contextlib, inspect, secrets, uuid, pathlib, io
+- Async/Threading: threading, asyncio, concurrent.futures
+- Network: socket, urllib.parse

-**FORMAT YOUR RESPONSE AS:**
+**FORMAT YOUR RESPONSE AS:**

-```python
-def function_name(parameters):
-    \"\"\"Comprehensive docstring explaining the function.\"\"\"
-    # Implementation here
-    pass
-```
-"""
+```python
+def function_name(parameters):
+    \"\"\"Comprehensive docstring explaining the function.\"\"\"
+    # Implementation here
+    pass
+```
+
+Provide the final, working solution. Focus on correctness, readability, and efficiency."""

 def transform_sample(sample):
     # AceCode format with OSS filtering
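The restrictions in the prompt above are presumably enforced by the execution sandbox itself; purely as an illustration of what a client-side pre-check could look like (hypothetical helper, not part of this commit), the forbidden constructs can be detected with the `ast` module:

import ast

FORBIDDEN_NODES = (ast.Global, ast.Nonlocal, ast.Yield, ast.YieldFrom)

def violates_prompt_rules(source: str) -> list[str]:
    """Return a list of prompt-rule violations found in a candidate solution."""
    violations = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, FORBIDDEN_NODES):
            violations.append(type(node).__name__.lower())
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            violations.append(f"dunder attribute {node.attr}")
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in {"input", "locals", "globals"}:
            violations.append(f"call to {node.func.id}()")
    return violations

print(violates_prompt_rules("def f(xs):\n    yield from xs\n"))  # ['yieldfrom']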
@@ -505,33 +550,56 @@ async def main(cfg: DictConfig):

     # ---- Setup services ---- #

-    # Setup coding environment using GenericOpenEnvActor with CodingEnv
-    # This actor provides a sandboxed Python execution environment via OpenEnv.
-    # Get docker image and env vars from config, with sensible defaults
-    coding_env_config = cfg.get("coding_env", {})
-    docker_image = coding_env_config.get("docker_image", "coding-env:latest")
-    additional_imports = coding_env_config.get(
-        "additional_imports", ["sys", "os", "functools", "typing"]
-    )
-    env_vars = {"PYTHON_ADDITIONAL_IMPORTS": ",".join(additional_imports)}
-    container_timeout_s = coding_env_config.get("container_timeout_s", 180.0)
-    request_timeout_s = coding_env_config.get("request_timeout_s", 120.0)
-    container_memory_gb = coding_env_config.get("container_memory_gb", 4)
-
-    coder_actor = await GenericOpenEnvActor.as_actor(
-        env_class=CodingEnv,
-        action_class=CodeAction,
-        docker_image=docker_image,
-        env_vars=env_vars,
-        container_timeout_s=container_timeout_s,
-        request_timeout_s=request_timeout_s,
-        container_memory_gb=container_memory_gb,
-        enable_zombie_cleanup=True, # Enable for code execution environments
+    # Setup coding environment with comprehensive standard library imports
+    # Based on analysis of 143 numpy, 47 requests, 35 urllib.parse, 31 socket, 31 dataclasses import failures
+    coder_actor = await OpenEnvCoder.as_actor(
+        additional_imports=[
+            # Core (default)
+            "sys",
+            "os",
+            "functools",
+            "typing",
+            # Data Science & Numerical
+            "numpy",
+            "pandas",
+            # Data Structures & Collections (31 dataclasses, 22 copy, 19 heapq, 17 enum)
+            "dataclasses",
+            "copy",
+            "heapq",
+            "enum",
+            # String & Text Processing (22 string, 21 ast)
+            "string",
+            "ast",
+            # Data Formats & Serialization (25 json, 15 struct, 10 base64, 5 csv)
+            "json",
+            "struct",
+            "base64",
+            "csv",
+            # Math & Numbers (12 cmath)
+            "cmath",
+            # Abstract Base Classes & Patterns (16 abc, 7 contextlib, 7 inspect)
+            "abc",
+            "contextlib",
+            "inspect",
+            # Security & Utilities (16 secrets, 4 uuid)
+            "secrets",
+            "uuid",
+            # I/O & Path Operations (6 pathlib, 5 io)
+            "pathlib",
+            "io",
+            # Async & Concurrency (11 threading, 6 asyncio, 3 concurrent.futures)
+            "threading",
+            "asyncio",
+            "concurrent.futures",
+            # Network & Web (35 urllib.parse, 31 socket)
+            "urllib.parse",
+            "socket",
+        ]
     )

     # Setup coding reward functions
     ground_truth_reward = GroundTruthTestReward(coder_actor)
-    thinking_reward = ThinkingReward()
+    # thinking_reward = ThinkingReward()

     (
         dataloader,
@@ -553,7 +621,7 @@ async def main(cfg: DictConfig):
         ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
         ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
         RewardActor.options(**cfg.services.reward_actor).as_service(
-            reward_functions=[ground_truth_reward, thinking_reward]
+            reward_functions=[ground_truth_reward]
         ),
     )

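The comment in the setup hunk above cites a tally of import failures (143 numpy, 47 requests, 35 urllib.parse, ...). A hedged sketch of how such a tally could be produced from rollout error logs; the log file name and message format are assumptions for illustration only, not part of the commit:

import re
from collections import Counter

pattern = re.compile(r"No module named '([\w\.]+)'")
counts = Counter()

with open("rollout_errors.log") as f:  # hypothetical log file
    for line in f:
        match = pattern.search(line)
        if match:
            counts[match.group(1)] += 1

for module, n in counts.most_common(10):
    print(f"{n:4d}  {module}")  # e.g. " 143  numpy"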
@@ -642,36 +710,6 @@ async def continuous_rollouts():
                    "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
                )
                t.stop()
-
-                # CRITICAL: Explicit memory cleanup to prevent leaks
-                # Clear tensor references
-                del episodes, advantages, responses
-                # Clear CUDA cache if using GPU
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-                # Force garbage collection every rollout to prevent accumulation
-                gc.collect()
-
-                # CRITICAL: Clean up zombie processes periodically (every 5 rollouts)
-                # This prevents accumulation of timed-out processes consuming memory
-                if rollout_count % 5 == 0:
-                    killed_count = await coder_actor.cleanup_zombie_processes.call_one()
-                    if killed_count > 0:
-                        print(
-                            f"Rollout {rollout_count}: Cleaned up {killed_count} zombie processes"
-                        )
-                        record_metric(
-                            "memory/zombie_processes_killed", killed_count, Reduce.SUM
-                        )
-
-                # Log memory usage periodically (every 10 rollouts) if psutil is available
-                if rollout_count % 10 == 0 and HAS_PSUTIL:
-                    process = psutil.Process()
-                    memory_mb = process.memory_info().rss / 1024 / 1024
-                    record_metric("memory/process_memory_mb", memory_mb, Reduce.MEAN)
-                    print(
-                        f"Rollout {rollout_count}: Process memory = {memory_mb:.2f} MB"
-                    )
            except RuntimeError as e:
                error_msg = str(e).lower()
                # Check if this is a container-related error that couldn't be recovered
@@ -759,7 +797,7 @@ async def continuous_training():
    except KeyboardInterrupt:
        print("Training interrupted by user")
    finally:
-        print("Shutting down... (this may take a few seconds)")
+        print("Shutting down...")
        shutdown_event.set()

        try:
@@ -784,6 +822,8 @@ async def continuous_training():
 @parse
 def _main(cfg):
     """Main entry point for GRPO training."""
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+    os.environ["NCCL_TIMEOUT_MS"] = "60000" # 60 second timeout
     os.environ["MONARCH_HOSTMESH_V1"] = "1"
     os.environ["TORCHSTORE_RDMA_ENABLED"] = "1"
     # os.environ["FORGE_DISABLE_METRICS"] = "1"

apps/julia-grpo/llama3_8b_julia.yaml

Lines changed: 3 additions & 4 deletions
@@ -2,12 +2,12 @@
 # >>> python -m apps.julia-grpo.main --config apps/julia-grpo/llama3_8b_julia.yaml

 # Global configuration
-group_size: 16 # num_generations from unsloth.py
-batch_size: 4 # per_device_train_batch_size from unsloth.py
+group_size: 8 # num_generations from unsloth.py
+batch_size: 2 # per_device_train_batch_size from unsloth.py
 max_req_tokens: 1024 # max_prompt_length from unsloth.py
 max_res_tokens: 2048 # max_completion_length from unsloth.py
 model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
-off_by_n: 2 # Off by one by default
+off_by_n: 0 # Off by one by default

 # Main loop configuration
 rollout_threads: 1 # Single thread for Julia code generation
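Rough arithmetic on the knobs changed above, assuming group_size × batch_size completions are sampled per rollout step (this mapping is an assumption for illustration, not something the config file states):

# Python, illustration only
old = dict(group_size=16, batch_size=4, max_res_tokens=2048)
new = dict(group_size=8, batch_size=2, max_res_tokens=2048)
for name, cfg in (("old", old), ("new", new)):
    completions = cfg["group_size"] * cfg["batch_size"]
    print(f"{name}: {completions} completions/step, <= {completions * cfg['max_res_tokens']} response tokens")
# old: 64 completions/step, <= 131072 response tokens
# new: 16 completions/step, <= 32768 response tokens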
@@ -70,7 +70,6 @@ trainer:
     name: AdamW
     lr: 5e-6 # learning_rate from unsloth.py
     eps: 1e-8
-    weight_decay: 0.01 # weight_decay from unsloth.py
   lr_scheduler:
     warmup_steps: 0 # warmup_ratio=0.1 * max_steps=500 from unsloth.py
   training:
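One caveat on dropping `weight_decay` above: if the optimizer config maps directly onto `torch.optim.AdamW` (an assumption about the trainer, not stated in this diff), removing the key falls back to AdamW's default of 0.01 rather than disabling decay; turning it off would require an explicit 0.0. Illustration:

import torch

model = torch.nn.Linear(8, 8)
# weight_decay is not passed, so torch.optim.AdamW uses its default of 1e-2
opt = torch.optim.AdamW(model.parameters(), lr=5e-6, eps=1e-8)
print(opt.defaults["weight_decay"])  # 0.01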

apps/openenv/julia_utils.py

Lines changed: 7 additions & 2 deletions
@@ -10,7 +10,7 @@
 """

 import re
-from typing import Dict, Any
+from typing import Any, Dict

 from forge.observability.metrics import record_metric, Reduce

@@ -143,7 +143,6 @@ def evaluate_julia_response(result, response: str, sample: Dict[str, Any]) -> fl

     # Extract reward from result
     reward = result.reward if result.reward is not None else 0.0
-    record_metric("reward/julia/reward", reward, Reduce.MEAN)

     obs = result.observation
     passed = obs.tests_passed
@@ -211,6 +210,12 @@ def transform_julia_sample(sample: Dict[str, Any], tokenizer) -> Dict[str, Any]
     """
     # Validate required fields
     if not sample.get("julia_test") or not sample.get("first_test_case"):
+        # Debug: log why sample was rejected (only for first few)
+        if not hasattr(transform_julia_sample, "_warned"):
+            print(
+                f"WARNING: Sample rejected - missing 'julia_test' or 'first_test_case' field. Sample keys: {list(sample.keys())}"
+            )
+            transform_julia_sample._warned = True
         return None

     # Build prompt
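The guard added above uses a function attribute so the warning prints only once. A side sketch (not part of the commit) of the same warn-once behavior as a small reusable helper built on the standard `warnings` module:

import warnings

_seen = set()

def warn_once(key: str, message: str) -> None:
    """Emit `message` only the first time `key` is seen."""
    if key not in _seen:
        _seen.add(key)
        warnings.warn(message)

warn_once("missing_fields", "Sample rejected - missing 'julia_test' or 'first_test_case'")
warn_once("missing_fields", "this second call is suppressed")  # no output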
