
Commit 45eb52d

Commit message: resolving merge conflicts
Merge commit with 2 parents: 264ccc8 + 2bcf56c

10 files changed: +438 additions, -183 deletions

apps/grpo/main.py

Lines changed: 34 additions & 33 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+
 import asyncio
 import logging
 import uuid
@@ -13,13 +15,16 @@
 import torch
 import torch.nn.functional as F
 from datasets import load_dataset
-from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.policy import Policy
 from forge.actors.replay_buffer import ReplayBuffer
+from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
+from omegaconf import DictConfig
+from src.forge.data.utils import exclude_service
 from torch import nn
 from transformers import AutoModelForCausalLM
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -286,11 +291,11 @@ async def forward(self, episode: Episode) -> torch.Tensor:
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""
 
-    path: str
-    revision: str
-    data_split: str
-    streaming: bool
-    model: str
+    path: str = "openai/gsm8k"
+    revision: str = "main"
+    data_split: str = "train"
+    streaming: bool = True
+    model: str = "Qwen/Qwen3-1.7B-Base"
 
     @endpoint
     def setup(self):
@@ -326,12 +331,13 @@ async def pad_token(self):
         return self.tokenizer.pad_token_id
 
 
-async def main():
+async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
-    group_size = 4
-    model = "Qwen/Qwen3-1.7B-Base"
-    max_req_tokens = 512
-    max_res_tokens = 128
+    # Get parameters from config with fallbacks
+    group_size = cfg.group_size
+    model = cfg.model
+    max_req_tokens = cfg.max_req_tokens
+    max_res_tokens = cfg.max_res_tokens
 
     # ---- Setup WandB Logger ---- #
     logger = get_metric_logger(
@@ -351,47 +357,37 @@ async def main():
         reward_actor,
     ) = await asyncio.gather(
        spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ServiceConfig(**cfg.dataset.service),
            DatasetActor,
-            path="openai/gsm8k",
-            revision="main",
-            data_split="train",
-            streaming=True,
-            model=model,
+            **exclude_service(cfg.dataset),
        ),
        spawn_service(
-            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            ServiceConfig(**cfg.policy.service),
            Policy,
-            config=PolicyConfig(
-                worker_params=WorkerConfig(model=model),
-                sampling_params=SamplingOverrides(
-                    n=group_size, max_tokens=max_res_tokens
-                ),
-            ),
+            **exclude_service(cfg.policy),
        ),
        spawn_service(
-            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            ServiceConfig(**cfg.trainer.service),
            Trainer,
-            learning_rate=1e-5,
            model_name=model,
+            **exclude_service(cfg.trainer),
        ),
        spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ServiceConfig(**cfg.replay_buffer.service),
            ReplayBuffer,
-            batch_size=4,
-            max_policy_age=1,
+            **exclude_service(cfg.replay_buffer),
        ),
        spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ServiceConfig(**cfg.compute_advantages.service),
            ComputeAdvantages,
        ),
        spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+            ServiceConfig(**cfg.ref_model.service),
            RefModel,
            model_name=model,
        ),
        spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ServiceConfig(**cfg.reward_actor.service),
            RewardActor,
            reward_functions=[MathReward(), ThinkingReward()],
        ),
@@ -485,5 +481,10 @@ async def continuous_training():
     )
 
 
+@parse
+def recipe_main(cfg: DictConfig) -> None:
+    asyncio.run(main(cfg))
+
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    recipe_main()
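Note: `exclude_service` (imported here from `src.forge.data.utils`) is not defined in this diff. Below is a minimal sketch of the behavior the call sites assume, namely that each config section is splatted into the actor constructor minus its `service` block; the actual helper may differ.

```python
from typing import Any

from omegaconf import DictConfig, OmegaConf


def exclude_service(cfg: DictConfig) -> dict[str, Any]:
    """Return a config section as plain constructor kwargs, without `service`."""
    # Resolve interpolations such as ${model}, then drop the service sub-config
    # so the remaining keys can be passed as **kwargs to spawn_service.
    resolved = OmegaConf.to_container(cfg, resolve=True)
    return {k: v for k, v in resolved.items() if k != "service"}
```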

apps/grpo/qwen3_1_7b.yaml

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+# GRPO Training Configuration
+
+# Global configuration
+group_size: 4
+batch_size: 4
+max_req_tokens: 512
+max_res_tokens: 128
+model: "Qwen/Qwen3-1.7B-Base"
+
+# Dataset configuration
+dataset:
+  path: "openai/gsm8k"
+  revision: "main"
+  data_split: "train"
+  streaming: true
+  service:
+    procs_per_replica: 1
+    num_replicas: 1
+    with_gpus: false
+
+# Policy configuration
+policy:
+  engine_config:
+    model: ${model}
+    tensor_parallel_size: 1
+    pipeline_parallel_size: 1
+    enforce_eager: true
+  sampling_config:
+    n: 4
+    max_tokens: 128
+    temperature: 1.0
+    top_p: 1.0
+  service:
+    procs_per_replica: 1
+    num_replicas: 1
+    with_gpus: true
+
+# Trainer configuration
+trainer:
+  learning_rate: 1e-5
+  service:
+    procs_per_replica: 1
+    num_replicas: 1
+    with_gpus: true
+
+# Replay buffer configuration
+replay_buffer:
+  batch_size: ${batch_size}
+  max_policy_age: 0
+  service:
+    procs_per_replica: 1
+    num_replicas: 1
+    with_gpus: false
+
+# Compute advantages configuration
+compute_advantages:
+  service:
+    procs_per_replica: 1
+    num_replicas: 1
+    with_gpus: false
+
+# Reference model configuration
+ref_model:
+  service:
+    procs_per_replica: 1
+    num_replicas: 1
+    with_gpus: true
+
+# Reward actor configuration
+reward_actor:
+  service:
+    procs_per_replica: 1
+    num_replicas: 1
+    with_gpus: false
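The `${model}` and `${batch_size}` values are OmegaConf interpolations against the top-level keys, so per-section settings stay in sync with the globals. A small illustrative loading snippet (not part of the commit):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("apps/grpo/qwen3_1_7b.yaml")

# Interpolations resolve against the top-level keys.
policy = OmegaConf.to_container(cfg.policy, resolve=True)
assert policy["engine_config"]["model"] == "Qwen/Qwen3-1.7B-Base"
assert cfg.replay_buffer.batch_size == 4

# Each section's `service` block maps directly onto ServiceConfig kwargs.
print(dict(cfg.trainer.service))  # {'procs_per_replica': 1, 'num_replicas': 1, 'with_gpus': True}
```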

apps/vllm/llama3_8b.yaml

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+policy:
+  engine_config:
+    model: "meta-llama/Llama-3.1-8B-Instruct"
+    tensor_parallel_size: 2
+    pipeline_parallel_size: 1
+    enforce_eager: true
+  sampling_config:
+    n: 2
+    guided_decoding: false
+    max_tokens: 512
+  available_devices: null
+  service:
+    procs_per_replica: 2
+    num_replicas: 1
+    with_gpus: true
+
+
+# Optional, otherwise argparse fallback kicks in
+prompt: "Tell me a joke"

apps/vllm/main.py

Lines changed: 21 additions & 77 deletions
@@ -6,95 +6,34 @@
 
 """To run:
 export HF_HUB_DISABLE_XET=1
-python -m apps.vllm.main --guided-decoding --num-samples 3
-
+python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml
 """
 
-import argparse
 import asyncio
-from argparse import Namespace
 
-from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.policy import Policy
+from forge.cli.config import parse
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
-from vllm.outputs import RequestOutput
-from vllm.transformers_utils.tokenizer import get_tokenizer
-
-
-async def main():
-    """Main application for running vLLM policy inference."""
-    args = parse_args()
 
-    # Create configuration objects
-    policy_config, service_config = get_configs(args)
-
-    # Resolve the Prompts
-    if args.prompt is None:
-        prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
-    else:
-        prompt = args.prompt
-
-    # format prompt
-    tokenizer = get_tokenizer(policy_config.worker_params.model)
-    messages = [{"role": "user", "content": prompt}]
-    prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-
-    # Run the policy
-    await run_vllm(service_config, policy_config, prompt)
-
-
-def parse_args() -> Namespace:
-    parser = argparse.ArgumentParser(description="VLLM Policy Inference Application")
-    parser.add_argument(
-        "--model",
-        type=str,
-        default="Qwen/Qwen3-1.7B",  # "meta-llama/Llama-3.1-8B-Instruct",
-        help="Model to use",
-    )
-    parser.add_argument(
-        "--num-samples", type=int, default=2, help="Number of samples to generate"
-    )
-    parser.add_argument(
-        "--guided-decoding", action="store_true", help="Enable guided decoding"
-    )
-    parser.add_argument(
-        "--prompt", type=str, default=None, help="Custom prompt to use for generation"
-    )
-    return parser.parse_args()
+from omegaconf import DictConfig
+from src.forge.data.utils import exclude_service
+from vllm.outputs import RequestOutput
 
 
-def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
+async def run(cfg: DictConfig):
 
-    worker_size = 2
-    worker_params = WorkerConfig(
-        model=args.model,
-        tensor_parallel_size=worker_size,
-        pipeline_parallel_size=1,
-        enforce_eager=True,
-        vllm_args=None,
-    )
+    if (prompt := cfg.get("prompt")) is None:
+        gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
+        prompt = "What is 3+5?" if gd else "Tell me a joke"
 
-    sampling_params = SamplingOverrides(
-        n=args.num_samples,
-        guided_decoding=args.guided_decoding,
-        max_tokens=16,
-    )
+    print("Spawning service...")
 
-    policy_config = PolicyConfig(
-        worker_params=worker_params, sampling_params=sampling_params
-    )
-    service_config = ServiceConfig(
-        procs_per_replica=worker_size, num_replicas=1, with_gpus=True
+    policy = await spawn_service(
+        ServiceConfig(**cfg.policy.service),
+        Policy,
+        **exclude_service(cfg.policy),
     )
 
-    return policy_config, service_config
-
-
-async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
-    print("Spawning service...")
-    policy = await spawn_service(service_config, Policy, config=config)
-
     async with policy.session():
         print("Requesting generation...")
         response_output: RequestOutput = await policy.generate.choose(prompt=prompt)
@@ -112,5 +51,10 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:
         await shutdown_service(policy)
 
 
+@parse
+def recipe_main(cfg: DictConfig) -> None:
+    asyncio.run(run(cfg))
+
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    recipe_main()
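Both apps now share the same entry-point shape: a thin `recipe_main` wrapped by `@parse` from `forge.cli.config`. The decorator itself is not part of this diff; the sketch below shows one plausible shape, assuming it reads a `--config` YAML path and hands the loaded `DictConfig` to the wrapped function (the override handling is a guess):

```python
import argparse
from typing import Callable

from omegaconf import DictConfig, OmegaConf


def parse(fn: Callable[[DictConfig], None]) -> Callable[[], None]:
    """Wrap an entry point so it receives a DictConfig loaded from --config."""

    def wrapper() -> None:
        parser = argparse.ArgumentParser()
        parser.add_argument("--config", required=True, help="Path to a YAML config file")
        args, extra = parser.parse_known_args()
        cfg = OmegaConf.load(args.config)
        # Treat any leftover key=value args as dotlist overrides on top of the file.
        cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(extra))
        fn(cfg)

    return wrapper
```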
