Commit 2693f1f

1500 steps working
1 parent 2f13a21 commit 2693f1f

4 files changed: +729 −24 lines changed

apps/julia-grpo/llama3_8b_julia.yaml

Lines changed: 12 additions & 10 deletions
@@ -2,12 +2,12 @@
 # >>> python -m apps.julia-grpo.main --config apps/julia-grpo/llama3_8b_julia.yaml
 
 # Global configuration
-group_size: 8 # num_generations from unsloth.py
+group_size: 16 # num_generations from unsloth.py
 batch_size: 4 # per_device_train_batch_size from unsloth.py
-max_req_tokens: 2048 # max_prompt_length from unsloth.py
-max_res_tokens: 1024 # max_completion_length from unsloth.py
+max_req_tokens: 1024 # max_prompt_length from unsloth.py
+max_res_tokens: 2048 # max_completion_length from unsloth.py
 model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
-off_by_n: 0 # Off by one by default
+off_by_n: 2 # Off by one by default
 
 # Main loop configuration
 rollout_threads: 1 # Single thread for Julia code generation
@@ -38,7 +38,9 @@ openenv_config:
   env_vars: {} # Additional environment variables if needed
   container_timeout_s: 180.0 # Timeout for container operations
   request_timeout_s: 120.0 # Timeout for code execution requests
-  container_memory_gb: 4 # Memory limit for containers
+  container_memory_gb: 128 # Memory limit for containers
+  port: 8000 # port for container communication
+  num_worker: 8 # number of workers
 
 # Policy configuration
 policy:
@@ -66,16 +68,16 @@ trainer:
   hf_assets_path: hf://${model}
   optimizer:
     name: AdamW
-    lr: 1e-5 # learning_rate from unsloth.py
+    lr: 5e-6 # learning_rate from unsloth.py
     eps: 1e-8
     weight_decay: 0.01 # weight_decay from unsloth.py
   lr_scheduler:
-    warmup_steps: 50 # warmup_ratio=0.1 * max_steps=500 from unsloth.py
+    warmup_steps: 0 # warmup_ratio=0.1 * max_steps=500 from unsloth.py
   training:
     local_batch_size: ${batch_size}
     seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
-    steps: 500 # max_steps from unsloth.py
+    steps: 3000 # max_steps from unsloth.py
     dtype: bfloat16
     gc_freq: 1
   compile:
@@ -90,11 +92,11 @@ trainer:
     disable_loss_parallel: true
   checkpoint:
     enable: true
-    folder: "checkpoint_llama3_8b_julia1107"
+    folder: "checkpoint_llama3_8b_julia1109"
     initial_load_path: hf://${model}
     initial_load_in_hf: true
     last_save_in_hf: true
-    interval: 100 # save_steps from unsloth.py
+    interval: 150 # save_steps from unsloth.py
     async_mode: "disabled"
   activation_checkpoint:
     mode: selective
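
Note on the token-budget change: max_req_tokens and max_res_tokens swap values (2048/1024 becomes 1024/2048), so the derived trainer.training.seq_len is unchanged. A minimal arithmetic sketch, using illustrative Python variables that mirror the YAML keys (not code from the repo):

    # Budgets after this commit (values taken from the YAML above).
    max_req_tokens = 1024  # max_prompt_length
    max_res_tokens = 2048  # max_completion_length

    # trainer.training.seq_len is ${sum:${max_req_tokens},${max_res_tokens}},
    # so swapping the two budgets leaves the packed sequence length at 3072.
    seq_len = max_req_tokens + max_res_tokens
    assert seq_len == 3072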

apps/julia-grpo/main.py

Lines changed: 32 additions & 12 deletions
@@ -281,6 +281,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
 
         # Extract reward from result
         reward = result.reward if result.reward is not None else 0.0
+        record_metric("reward/julia/reward", reward, Reduce.MEAN)
         obs = result.observation
 
         passed = obs.tests_passed
@@ -293,19 +294,20 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
         print(f" Tests Passed: {passed}")
         print(f" Tests Failed: {failed}")
         print(f" Total Tests: {total}")
+        print(f" Exit Code: {obs.exit_code}")
+        print(f" Code Compiles: {obs.code_compiles}")
 
         if obs.stderr:
-            print(f" Stderr: {obs.stderr[:200]}")
+            print(f" Stderr: {obs.stderr[:500]}")
             record_metric("reward/julia/has_errors", 1, Reduce.SUM)
 
-        if obs.error_message:
-            print(f" Error Message: {obs.error_message[:200]}")
+        if obs.stdout:
+            print(f" Stdout (first 200 chars): {obs.stdout[:200]}")
 
         # Log metrics
-        record_metric("reward/julia/tests_passed", passed, Reduce.SUM)
-        record_metric("reward/julia/tests_failed", failed, Reduce.SUM)
-        record_metric("reward/julia/tests_total", total, Reduce.SUM)
-        record_metric("reward/julia/pass_rate", reward, Reduce.MEAN)
+        pass_rate = passed / total if total > 0 else 0.0
+
+        record_metric("reward/julia/pass_rate", pass_rate, Reduce.MEAN)
 
         print(f"Final Reward: {reward:.3f}")
         print("=" * 80)
@@ -337,7 +339,7 @@ def _extract_code(self, response: str) -> str:
 class ComputeAdvantages(ForgeActor):
     @endpoint
     async def compute(self, group: Group) -> list[float]:
-        rewards = torch.tensor([[e.reward for e in group]])
+        rewards = torch.tensor([[e.reward for e in group]], dtype=torch.float32)
         mean = rewards.mean(1, keepdim=True)
         std = rewards.std(1, keepdim=True)
         advantages = (rewards - mean) / (std + 1e-4)
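
Aside on the dtype change: if every reward in a group comes back as a Python int, torch.tensor would infer an integer dtype and the mean/std calls that follow would fail, so forcing float32 keeps the group normalization well defined. A standalone sketch of the same group-relative advantage computation with made-up reward values:

    import torch

    # One group of rewards for a single prompt (illustrative values only).
    rewards = torch.tensor([[0.0, 1.0, 1.0, 0.25]], dtype=torch.float32)

    mean = rewards.mean(1, keepdim=True)          # per-group mean
    std = rewards.std(1, keepdim=True)            # per-group standard deviation
    advantages = (rewards - mean) / (std + 1e-4)  # eps keeps this finite when all rewards tie

    print(advantages.tolist())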
@@ -517,6 +519,14 @@ async def main(cfg: DictConfig):
     request_timeout_s = openenv_config.get("request_timeout_s", 120.0)
     container_memory_gb = openenv_config.get("container_memory_gb", 4)
 
+    # Set PORT and NUM_WORKER environment variables for the Julia server
+    # These match the Dockerfile defaults
+    if "PORT" not in env_vars:
+        env_vars["PORT"] = str(openenv_config.get("port", 8000))
+    if "NUM_WORKER" not in env_vars:
+        env_vars["NUM_WORKER"] = str(openenv_config.get("num_worker", 4))
+    if "JULIA_MAX_WORKERS" not in env_vars:
+        env_vars["JULIA_MAX_WORKERS"] = str(openenv_config.get("julia_max_workers", 16))
     julia_env_actor = await GenericOpenEnvActor.options(
         **cfg.actors.julia_env
     ).as_actor(
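
The three guards above only set PORT, NUM_WORKER, and JULIA_MAX_WORKERS when the user has not already placed them in openenv_config.env_vars; note that the YAML in this commit sets num_worker: 8 while the code falls back to 4 when the key is absent. A hypothetical condensed equivalent using dict.setdefault, shown only to illustrate the precedence (explicit env_vars win over the dedicated config keys, which win over the hard-coded fallbacks); this is not code from the commit:

    # Hypothetical rewrite of the same defaulting logic.
    defaults = {
        "PORT": ("port", 8000),
        "NUM_WORKER": ("num_worker", 4),
        "JULIA_MAX_WORKERS": ("julia_max_workers", 16),
    }
    for env_name, (cfg_key, fallback) in defaults.items():
        env_vars.setdefault(env_name, str(openenv_config.get(cfg_key, fallback)))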
@@ -587,12 +597,14 @@ async def continuous_rollouts():
             responses: list[Completion] = await policy.generate.route(prompt)
             t.step("policy_generation")
 
-            # Construct episodes and calculate rewards
+            # Construct episodes and calculate rewards in parallel
             episodes = []
             input_ids = torch.ones(
                 (group_size, max_req_tokens + max_res_tokens),
                 dtype=torch.long,
             )
+
+            # Create episodes first
             for i, response in enumerate(responses):
                 episode = Episode(
                     episode_id=str(uuid.uuid4()),
@@ -602,12 +614,20 @@ async def continuous_rollouts():
                     target=target,
                     completion=response,
                 )
-                episode.reward = await reward_actor.evaluate_response.route(
+                episodes.append(episode)
+
+            # Evaluate all rewards in parallel
+            reward_tasks = [
+                reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
-                episodes.append(episode)
+                for response in responses
+            ]
+            rewards = await asyncio.gather(*reward_tasks)
 
-            # Build input_ids for reference logprobs
+            # Assign rewards and build input_ids
+            for i, (episode, reward) in enumerate(zip(episodes, rewards)):
+                episode.reward = reward
                 input_ids[i, :max_req_tokens] = episode.request_tensor
                 input_ids[i, max_req_tokens:] = episode.response_tensor
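
The rollout change splits episode construction from reward evaluation so that all completions in a group can be scored concurrently with asyncio.gather instead of one container round-trip at a time. A minimal self-contained sketch of that serial-to-parallel pattern, with a dummy coroutine standing in for reward_actor.evaluate_response.route (timings and rewards are illustrative):

    import asyncio

    async def evaluate(i: int) -> float:
        # Stand-in for one Julia evaluation round-trip.
        await asyncio.sleep(1.0)
        return float(i % 2)

    async def demo() -> None:
        # Before: awaiting inside the loop runs the evaluations one after another.
        serial = [await evaluate(i) for i in range(16)]

        # After: gather schedules all 16 evaluations at once and awaits them together.
        parallel = await asyncio.gather(*(evaluate(i) for i in range(16)))

        assert serial == list(parallel)

    asyncio.run(demo())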
