
Commit 3b7ee6d

Merge remote-tracking branch 'origin/main' into ref-actor

2 parents: 4ca0685 + ccd2377

File tree: 14 files changed, +329 −256 lines


apps/grpo/main.py

Lines changed: 75 additions & 57 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import asyncio
+import logging
 import time
 from dataclasses import dataclass
 from typing import Callable
@@ -15,12 +16,15 @@
 from forge.actors.reference_actor import compute_sequence_logprobs, RefModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.controller.actor import ForgeActor
-from forge.controller.service import ServiceConfig, spawn_service
+from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
 
 @dataclass
 class Group:
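
This hunk also wires up module-level logging. A minimal sketch of what the two added lines give you; the handler setup via `basicConfig` is an assumption, since the diff leaves handler configuration to the application:

```python
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# The logger's own threshold is now DEBUG, so debug records pass its filter;
# whether they are actually emitted still depends on the attached handlers.
logging.basicConfig(level=logging.DEBUG)  # assumption: app attaches a DEBUG handler
logger.debug("reference-actor debug output is now visible")
```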
@@ -242,18 +246,18 @@ async def __call__(self, groups: list[Group]) -> list[float]:
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self, path: str, config_name: str, split: str, streaming: bool, **kwargs
+    ):
         super().__init__()
-        self._setup_dataset(*args, **kwargs)
 
-    def _setup_dataset(self, *args, **kwargs):
         def gsm8k_to_messages(sample):
             question = sample["question"]
             full_answer: str = sample["answer"]
             answer = full_answer.split("#### ")[1]
             return {"question": question, "answer": answer}
 
-        ds = load_dataset(*args, **kwargs)
+        ds = load_dataset(path, config_name, split=split, streaming=streaming)
         ds = ds.map(gsm8k_to_messages)
         ds = ds.shuffle()
         self._iterator = iter(ds)
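
The `DatasetActor` change replaces the opaque `*args, **kwargs` pass-through with explicit, typed constructor parameters, so call sites (see the gather block below) read as named keyword arguments. A minimal sketch of the underlying dataset flow, assuming only the `datasets` library and the GSM8K schema visible in the diff:

```python
from datasets import load_dataset

# Same arguments the new explicit signature forwards to load_dataset
ds = load_dataset("openai/gsm8k", "main", split="train", streaming=True)

sample = next(iter(ds))
question = sample["question"]
answer = sample["answer"].split("#### ")[1]  # GSM8K puts the final answer after "#### "
print(question, answer)
```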
@@ -279,66 +283,69 @@ async def main():
     )
 
     # ---- Setup services ---- #
-    policy = await spawn_service(
-        ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
-        Policy,
-        PolicyConfig(
-            num_workers=1,
-            worker_params=WorkerConfig(model=model),
-            sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
+    (
+        dataloader,
+        policy,
+        trainer,
+        replay_buffer,
+        compute_advantages,
+        ref_model,
+        reward_actor,
+    ) = await asyncio.gather(
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            DatasetActor,
+            path="openai/gsm8k",
+            config_name="main",
+            split="train",
+            streaming=True,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            Policy,
+            config=PolicyConfig(
+                worker_params=WorkerConfig(model=model),
+                sampling_params=SamplingOverrides(
+                    num_samples=group_size, max_tokens=16
+                ),
+            ),
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            Trainer,
+            learning_rate=1e-5,
+            beta=0.1,
+            model_name=model,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ReplayBuffer,
+            batch_size=4,
+            max_policy_age=1,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ComputeAdvantages,
+            gamma=0.99,
+            lambda_=0.95,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+            RefModel,
+            model_name=model,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            RewardActor,
+            reward_functions=[MathReward(), ThinkingReward()],
         ),
-    )
-
-    trainer = await spawn_service(
-        ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
-        Trainer,
-        learning_rate=1e-5,
-        beta=0.1,
-        model_name=model,
-    )
-
-    replay_buffer = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1),
-        ReplayBuffer,
-        batch_size=4,
-        max_policy_age=1,
-    )
-
-    dataloader = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1),
-        DatasetActor,
-        "openai/gsm8k",
-        "main",
-        split="train",
-        streaming=True,
-    )
-
-    compute_advantages = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1),
-        ComputeAdvantages,
-        gamma=0.99,
-        lambda_=0.95,
-    )
-
-    ref_model = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
-        RefModel,
-        model_name=model,
-    )
-
-    reward_actor = await spawn_service(
-        ServiceConfig(procs_per_replica=1, num_replicas=1),
-        RewardActor,
-        reward_functions=[MathReward(), ThinkingReward()],
     )
 
     print("All services initialized successfully!")
 
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
-        # TODO: Move this into setup
-        asyncio.create_task(policy.run_processing.call())
         while True:
             sample = await dataloader.__next__.choose()
             if sample is None:
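
The seven sequential `await spawn_service(...)` calls are collapsed into a single `await asyncio.gather(...)`, so the services initialize concurrently instead of one after another. `asyncio.gather` returns results in the order its awaitables were passed, which is what makes the tuple unpacking on the left safe. A toy sketch of the pattern, where `spawn` is a hypothetical stand-in for `spawn_service`:

```python
import asyncio


async def spawn(name: str) -> str:
    # hypothetical stand-in for spawn_service: pretend startup takes a while
    await asyncio.sleep(0.1)
    return f"{name}-handle"


async def main() -> None:
    # All three "services" start concurrently; results come back in call
    # order, so the unpacking on the left matches the calls on the right.
    dataloader, policy, trainer = await asyncio.gather(
        spawn("dataloader"), spawn("policy"), spawn("trainer")
    )
    print(dataloader, policy, trainer)


asyncio.run(main())
```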
@@ -409,6 +416,17 @@ async def continuous_training():
         print("Training interrupted by user")
         rollout_task.cancel()
         training_task.cancel()
+    finally:
+        print("Shutting down...")
+        await asyncio.gather(
+            shutdown_service(policy),
+            shutdown_service(trainer),
+            shutdown_service(replay_buffer),
+            shutdown_service(dataloader),
+            shutdown_service(compute_advantages),
+            shutdown_service(ref_model),
+            shutdown_service(reward_actor),
+        )
 
 
 if __name__ == "__main__":
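
The new `finally:` block guarantees cleanup whether training completes, raises, or is interrupted, and shuts the seven services down concurrently with another `asyncio.gather`. A self-contained sketch of that pattern, with `shutdown` as a hypothetical stand-in for `shutdown_service`:

```python
import asyncio


async def shutdown(name: str) -> None:
    # hypothetical stand-in for shutdown_service
    print(f"stopped {name}")


async def main() -> None:
    try:
        raise KeyboardInterrupt  # simulate Ctrl-C during the training loops
    except KeyboardInterrupt:
        print("Training interrupted by user")
    finally:
        # finally runs on success, exception, or cancellation alike,
        # so every service is released no matter how the loops exited
        await asyncio.gather(shutdown("policy"), shutdown("trainer"))


asyncio.run(main())
```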

apps/vllm/main.py

Lines changed: 20 additions & 19 deletions
@@ -16,7 +16,7 @@
 from typing import List
 
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
-from forge.controller.service import ServiceConfig, spawn_service
+from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from vllm.outputs import CompletionOutput
 
 
@@ -58,9 +58,11 @@ def parse_args() -> Namespace:
 
 
 def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
+
+    worker_size = 2
     worker_params = WorkerConfig(
         model=args.model,
-        tensor_parallel_size=2,
+        tensor_parallel_size=worker_size,
         pipeline_parallel_size=1,
         enforce_eager=True,
         vllm_args=None,
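
The substantive fix in this file: the service previously requested `procs_per_replica=1` with no GPUs while the worker asked for `tensor_parallel_size=2`. Deriving both values from one `worker_size` constant keeps the scheduler's process count in lockstep with vLLM's tensor-parallel degree. A sketch of the invariant, with plain dicts standing in for the real config classes:

```python
# Assumption: one process (and one GPU) is needed per tensor-parallel rank.
worker_size = 2

worker_params = {"tensor_parallel_size": worker_size}  # vLLM side
service_config = {"procs_per_replica": worker_size, "num_replicas": 1, "with_gpus": True}

# The property the refactor enforces by construction:
assert worker_params["tensor_parallel_size"] == service_config["procs_per_replica"]
```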
@@ -72,35 +74,34 @@ def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
     )
 
     policy_config = PolicyConfig(
-        num_workers=2, worker_params=worker_params, sampling_params=sampling_params
+        worker_params=worker_params, sampling_params=sampling_params
+    )
+    service_config = ServiceConfig(
+        procs_per_replica=worker_size, num_replicas=1, with_gpus=True
     )
-    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
 
     return policy_config, service_config
 
 
 async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
     print("Spawning service...")
     policy = await spawn_service(service_config, Policy, config=config)
-    session_id = await policy.start_session()
 
-    print("Starting background processing...")
-    processing_task = asyncio.create_task(policy.run_processing.call())
+    async with policy.session():
+        print("Requesting generation...")
+        responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
 
-    print("Requesting generation...")
-    responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
+        print("\nGeneration Results:")
+        print("=" * 80)
+        for batch, response in enumerate(responses):
+            print(f"Sample {batch + 1}:")
+            print(f"User: {prompt}")
+            print(f"Assistant: {response.text}")
+            print("-" * 80)
 
-    print("\nGeneration Results:")
-    print("=" * 80)
-    for batch, response in enumerate(responses):
-        print(f"Sample {batch + 1}:")
-        print(f"User: {prompt}")
-        print(f"Assistant: {response.text}")
-        print("-" * 80)
+    print("\nShutting down...")
 
-    print("\nShutting down...")
-    await policy.shutdown.call()
-    await policy.terminate_session(session_id)
+    await shutdown_service(policy)
 
 
 if __name__ == "__main__":
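
Manual `start_session()` / `terminate_session()` bookkeeping is replaced by `async with policy.session():`, so the session is closed even if generation raises. A toy equivalent built on `contextlib.asynccontextmanager`; the `session` function here is a hypothetical stand-in, not the Policy API:

```python
import asyncio
import contextlib


@contextlib.asynccontextmanager
async def session():
    # hypothetical stand-in for policy.session()
    print("session opened")
    try:
        yield
    finally:
        print("session closed")  # runs even if the body raises


async def main() -> None:
    async with session():
        print("requesting generation...")


asyncio.run(main())
```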
