Skip to content

Commit e4723bb

Browse files
committed
Merge branch 'main' into ungroup
2 parents 3e32264 + 7357daa commit e4723bb

28 files changed

+2031
-456
lines changed

apps/grpo/main.py

Lines changed: 69 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import asyncio
88
import copy
9+
import logging
910
import time
1011
import uuid
1112
from dataclasses import dataclass
@@ -15,15 +16,18 @@
1516
from datasets import load_dataset
1617
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
1718
from forge.actors.replay_buffer import ReplayBuffer
18-
from forge.controller import ServiceConfig, spawn_service
1919
from forge.controller.actor import ForgeActor
20+
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
2021
from forge.data.rewards import MathReward, ThinkingReward
2122
from forge.util.metric_logging import get_metric_logger
2223
from monarch.actor import endpoint
2324
from torch import nn
2425
from transformers import AutoModelForCausalLM
2526
from vllm.transformers_utils.tokenizer import get_tokenizer
2627

28+
logger = logging.getLogger(__name__)
29+
logger.setLevel(logging.DEBUG)
30+
2731

2832
def compute_logprobs(
2933
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
@@ -365,66 +369,60 @@ async def main():
365369
)
366370

367371
# ---- Setup services ---- #
368-
default_service_cfg = ServiceConfig(
369-
procs_per_replica=1,
370-
num_replicas=1,
371-
)
372-
373-
policy = await spawn_service(
374-
default_service_cfg,
375-
Policy,
376-
PolicyConfig(
377-
num_workers=1,
378-
worker_params=WorkerConfig(model=model),
379-
sampling_params=SamplingOverrides(n=group_size, max_tokens=max_res_tokens),
380-
available_devices="3",
372+
(
373+
dataloader,
374+
policy,
375+
trainer,
376+
replay_buffer,
377+
compute_advantages,
378+
ref_model,
379+
reward_actor,
380+
) = await asyncio.gather(
381+
spawn_service(
382+
ServiceConfig(procs_per_replica=1, num_replicas=1),
383+
DatasetActor,
384+
path="openai/gsm8k",
385+
name="main",
386+
data_split="train",
387+
streaming=True,
388+
model=model,
389+
),
390+
spawn_service(
391+
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
392+
Policy,
393+
config=PolicyConfig(
394+
worker_params=WorkerConfig(model=model),
395+
sampling_params=SamplingOverrides(
396+
n=group_size, max_tokens=max_res_tokens
397+
),
398+
),
399+
),
400+
spawn_service(
401+
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
402+
Trainer,
403+
learning_rate=1e-5,
404+
model_name=model,
405+
),
406+
spawn_service(
407+
ServiceConfig(procs_per_replica=1, num_replicas=1),
408+
ReplayBuffer,
409+
batch_size=4,
410+
max_policy_age=1,
411+
),
412+
spawn_service(
413+
ServiceConfig(procs_per_replica=1, num_replicas=1),
414+
ComputeAdvantages,
415+
),
416+
spawn_service(
417+
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
418+
RefModel,
419+
model=titan_model,
420+
),
421+
spawn_service(
422+
ServiceConfig(procs_per_replica=1, num_replicas=1),
423+
RewardActor,
424+
reward_functions=[MathReward(), ThinkingReward()],
381425
),
382-
)
383-
384-
trainer = await spawn_service(
385-
default_service_cfg,
386-
Trainer,
387-
learning_rate=1e-5,
388-
beta=0.1,
389-
model_name=model,
390-
device=torch.device("cuda:1"),
391-
)
392-
393-
replay_buffer = await spawn_service(
394-
default_service_cfg,
395-
ReplayBuffer,
396-
batch_size=4,
397-
max_policy_age=1,
398-
)
399-
400-
dataloader = await spawn_service(
401-
default_service_cfg,
402-
DatasetActor,
403-
"openai/gsm8k",
404-
"main",
405-
data_split="train",
406-
streaming=True,
407-
model=model,
408-
)
409-
410-
compute_advantages = await spawn_service(
411-
default_service_cfg,
412-
ComputeAdvantages,
413-
gamma=0.99,
414-
lambda_=0.95,
415-
)
416-
417-
ref_model = await spawn_service(
418-
default_service_cfg,
419-
RefModel,
420-
model_name=model,
421-
device=torch.device("cuda:2"),
422-
)
423-
424-
reward_actor = await spawn_service(
425-
default_service_cfg,
426-
RewardActor,
427-
reward_functions=[MathReward(), ThinkingReward()],
428426
)
429427

430428
print("All services initialized successfully!")
@@ -433,8 +431,6 @@ async def main():
433431
async def continuous_rollouts():
434432
rollout_count = 0
435433
pad_id = dataloader.pad_token.choose()
436-
# TODO: Move this into setup
437-
asyncio.create_task(policy.run_processing.call())
438434
while True:
439435
sample = await dataloader.sample.choose()
440436
if sample is None:
@@ -501,6 +497,17 @@ async def continuous_training():
501497
print("Training interrupted by user")
502498
rollout_task.cancel()
503499
training_task.cancel()
500+
finally:
501+
print("Shutting down...")
502+
await asyncio.gather(
503+
shutdown_service(policy),
504+
shutdown_service(trainer),
505+
shutdown_service(replay_buffer),
506+
shutdown_service(dataloader),
507+
shutdown_service(compute_advantages),
508+
shutdown_service(ref_model),
509+
shutdown_service(reward_actor),
510+
)
504511

505512

506513
if __name__ == "__main__":

apps/rl/llama3_8b.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ trainer:
1818
processes:
1919
scheduler: local # local | mast (not supported yet)
2020
num_hosts: 1
21+
with_gpus: True
2122
num_procs: 4
2223

2324
optimizer:
@@ -33,9 +34,11 @@ trainer:
3334
seq_len: 2048
3435
max_norm: 1.0
3536
steps: 5
36-
compile: false
3737
dataset: "c4"
3838

39+
compile:
40+
enable: false
41+
3942
parallelism:
4043
data_parallel_replicate_degree: 1
4144
data_parallel_shard_degree: -1
@@ -65,6 +68,7 @@ replay_buffer:
6568
processes:
6669
scheduler: local # local | mast (not supported yet)
6770
num_hosts: 1
71+
with_gpus: False
6872
num_procs: 1
6973

7074
# policy:

apps/rl/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from forge.controller import spawn_actors
2121
from omegaconf import DictConfig
2222

23-
2423
logger = logging.getLogger(__name__)
2524
logger.setLevel(logging.INFO)
2625

@@ -30,7 +29,7 @@ async def run(cfg: DictConfig):
3029
spawn_actors(
3130
name="trainer",
3231
actor_cls=RLTrainer,
33-
cfg={"config": cfg.trainer},
32+
cfg=cfg.trainer,
3433
processes=cfg.trainer.pop("processes"),
3534
set_address=True,
3635
),

apps/sft_v2/llama3_8b.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ comm:
1414
model:
1515
name: llama3
1616
flavor: 8B
17-
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct
17+
tokenizer_path: /tmp/Llama-3.1-8B-Instruct
1818

1919
processes:
2020
scheduler: local # local | mast (not supported yet)
2121
num_hosts: 1
2222
num_procs: 8
23+
num_gpus: 8
2324

2425
optimizer:
2526
name: AdamW

apps/sft_v2/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
"""To run:
88
9-
python -m apps.sft.main --config apps/sft/llama3_8b.yaml
9+
python -m apps.sft_v2.main --config apps/sft_v2/llama3_8b.yaml
1010
1111
"""
1212

apps/vllm/main.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
"""To run:
8-
8+
export HF_HUB_DISABLE_XET=1
99
python -m apps.vllm.main --guided-decoding --num-samples 3
1010
1111
"""
@@ -16,8 +16,7 @@
1616
from typing import List
1717

1818
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
19-
from forge.controller.service import ServiceConfig
20-
from forge.controller.spawn import spawn_service
19+
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
2120
from vllm.outputs import CompletionOutput, RequestOutput
2221

2322

@@ -66,9 +65,11 @@ def parse_args() -> Namespace:
6665

6766

6867
def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
68+
69+
worker_size = 2
6970
worker_params = WorkerConfig(
7071
model=args.model,
71-
tensor_parallel_size=2,
72+
tensor_parallel_size=worker_size,
7273
pipeline_parallel_size=1,
7374
enforce_eager=True,
7475
vllm_args=None,
@@ -81,36 +82,35 @@ def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
8182
)
8283

8384
policy_config = PolicyConfig(
84-
num_workers=2, worker_params=worker_params, sampling_params=sampling_params
85+
worker_params=worker_params, sampling_params=sampling_params
86+
)
87+
service_config = ServiceConfig(
88+
procs_per_replica=worker_size, num_replicas=1, with_gpus=True
8589
)
86-
service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
8790

8891
return policy_config, service_config
8992

9093

9194
async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
9295
print("Spawning service...")
9396
policy = await spawn_service(service_config, Policy, config=config)
94-
session_id = await policy.start_session()
95-
96-
print("Starting background processing...")
97-
processing_task = asyncio.create_task(policy.run_processing.call())
98-
99-
print("Requesting generation...")
100-
request_output: RequestOutput = await policy.generate.choose(prompt=prompt)
101-
responses: List[CompletionOutput] = request_output.outputs
102-
103-
print("\nGeneration Results:")
104-
print("=" * 80)
105-
for batch, response in enumerate(responses):
106-
print(f"Sample {batch + 1}:")
107-
print(f"User: {prompt}")
108-
print(f"Assistant: {response.text}")
109-
print("-" * 80)
110-
111-
print("\nShutting down...")
112-
await policy.shutdown.call()
113-
await policy.terminate_session(session_id)
97+
98+
async with policy.session():
99+
print("Requesting generation...")
100+
request_output: RequestOutput = await policy.generate.choose(prompt=prompt)
101+
responses: List[CompletionOutput] = request_output.outputs
102+
103+
print("\nGeneration Results:")
104+
print("=" * 80)
105+
for batch, response in enumerate(responses):
106+
print(f"Sample {batch + 1}:")
107+
print(f"User: {prompt}")
108+
print(f"Assistant: {response.text}")
109+
print("-" * 80)
110+
111+
print("\nShutting down...")
112+
113+
await shutdown_service(policy)
114114

115115

116116
if __name__ == "__main__":

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"tokenizers",
2323
# Miscellaneous
2424
"omegaconf",
25+
"wandb",
2526
]
2627
dynamic = ["version"]
2728

src/forge/actors/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"]
7+
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]
88

99

1010
def __getattr__(name):
@@ -24,5 +24,9 @@ def __getattr__(name):
2424
from .replay_buffer import ReplayBuffer
2525

2626
return ReplayBuffer
27+
elif name == "TitanRefModel":
28+
from .reference_actor import TitanRefModel
29+
30+
return TitanRefModel
2731
else:
2832
raise AttributeError(f"module {__name__} has no attribute {name}")

0 commit comments

Comments
 (0)