-
Notifications
You must be signed in to change notification settings - Fork 33
Initial multi-host support #151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
3ab5c51
f4691ca
54b4444
a103ae9
e0f9228
7c1faab
6fb46ff
46e92eb
bc06a0a
373fb05
9b3016e
396b521
b12810e
130442c
fbe9b45
e3a02d6
c139154
39d399c
52592a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# GRPO Training Configuration | ||
allenwang28 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
# Currently a fork of the main YAML config; it only demonstrates | ||
# placing the trainer and inference servers on separate hosts.
|
||
# Global configuration | ||
group_size: 8 | ||
batch_size: 16 | ||
max_req_tokens: 512 | ||
max_res_tokens: 512 | ||
model: "Qwen/Qwen3-1.7B" | ||
|
||
# Dataset configuration | ||
dataset: | ||
path: "openai/gsm8k" | ||
revision: "main" | ||
data_split: "train" | ||
streaming: true | ||
model: ${model} | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: false | ||
|
||
# Policy configuration | ||
policy: | ||
engine_config: | ||
model: ${model} | ||
tensor_parallel_size: 1 | ||
pipeline_parallel_size: 1 | ||
enforce_eager: false | ||
sampling_config: | ||
n: ${group_size} | ||
max_tokens: ${max_res_tokens} | ||
temperature: 1.0 | ||
top_p: 1.0 | ||
service: | ||
procs_per_replica: 1 | ||
hosts_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: true | ||
|
||
# Trainer configuration | ||
trainer: | ||
model_name: ${model} | ||
learning_rate: 1e-5 | ||
service: | ||
procs_per_replica: 1 | ||
hosts_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: true | ||
|
||
# Replay buffer configuration | ||
replay_buffer: | ||
batch_size: ${batch_size} | ||
max_policy_age: 1 # Async by 1 | ||
dp_size: 1 | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: false | ||
|
||
# Compute advantages configuration | ||
compute_advantages: | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: false | ||
|
||
# Reference model configuration | ||
ref_model: | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: true | ||
|
||
# Reward actor configuration | ||
reward_actor: | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: false |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# NOTE - this won't work until we have proper HostMesh support | ||
policy: | ||
engine_config: | ||
model: "deepseek-ai/DeepSeek-R1-0528" | ||
tensor_parallel_size: 16 | ||
pipeline_parallel_size: 1 | ||
enable_expert_parallel: true | ||
# enforce_eager: true | ||
sampling_config: | ||
n: 2 | ||
guided_decoding: false | ||
max_tokens: 512 | ||
available_devices: null | ||
service: | ||
procs_per_replica: 8 | ||
hosts_per_replica: 2 | ||
allenwang28 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
num_replicas: 1 | ||
with_gpus: true | ||
|
||
|
||
# Optional, otherwise argparse fallback kicks in | ||
prompt: "Tell me a joke" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,14 +11,20 @@ | |
|
||
import asyncio | ||
|
||
import os | ||
|
||
from forge.actors.policy import Policy | ||
from forge.cli.config import parse | ||
from forge.controller.provisioner import shutdown | ||
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service | ||
|
||
from omegaconf import DictConfig | ||
from src.forge.data.utils import exclude_service | ||
from vllm.outputs import RequestOutput | ||
|
||
os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600" | ||
|
||
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824" | ||
|
||
|
||
async def run(cfg: DictConfig): | ||
|
||
|
@@ -27,28 +33,29 @@ async def run(cfg: DictConfig): | |
prompt = "What is 3+5?" if gd else "Tell me a joke" | ||
|
||
print("Spawning service...") | ||
|
||
policy = await spawn_service( | ||
ServiceConfig(**cfg.policy.service), | ||
Policy, | ||
**exclude_service(cfg.policy), | ||
) | ||
|
||
async with policy.session(): | ||
print("Requesting generation...") | ||
response_output: RequestOutput = await policy.generate.choose(prompt=prompt) | ||
try: | ||
async with policy.session(): | ||
print("Requesting generation...") | ||
response_output: RequestOutput = await policy.generate.choose(prompt=prompt) | ||
|
||
print("\nGeneration Results:") | ||
print("=" * 80) | ||
for batch, response in enumerate(response_output.outputs): | ||
print(f"Sample {batch + 1}:") | ||
print(f"User: {prompt}") | ||
print(f"Assistant: {response.text}") | ||
print("-" * 80) | ||
print("\nGeneration Results:") | ||
print("=" * 80) | ||
for batch, response in enumerate(response_output.outputs): | ||
print(f"Sample {batch + 1}:") | ||
print(f"User: {prompt}") | ||
print(f"Assistant: {response.text}") | ||
print("-" * 80) | ||
|
||
finally: | ||
print("\nShutting down...") | ||
|
||
await shutdown_service(policy) | ||
await shutdown_service(policy) | ||
await shutdown() | ||
|
||
|
||
@parse | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
policy: | ||
engine_config: | ||
model: "Qwen/Qwen2.5-32B" | ||
tensor_parallel_size: 4 | ||
pipeline_parallel_size: 1 | ||
enforce_eager: true | ||
sampling_config: | ||
n: 2 | ||
guided_decoding: false | ||
max_tokens: 512 | ||
available_devices: null | ||
service: | ||
procs_per_replica: 4 | ||
hosts_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: true | ||
|
||
|
||
# Optional, otherwise argparse fallback kicks in | ||
prompt: "Tell me a joke" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/bin/bash | ||
|
||
#SBATCH --job-name=forge | ||
#SBATCH --output=slogs/forge.out | ||
#SBATCH --error=slogs/forge.err | ||
#SBATCH --partition=h100-low # or h100-high / h100-prod / all | ||
#SBATCH --nodes=1 # 1 node | ||
#SBATCH --ntasks=1 # 1 task (process) | ||
#SBATCH --gres=gpu:8 # request 8 GPUs | ||
#SBATCH --time=01:00:00 # walltime hh:mm:ss | ||
|
||
unset SLURM_MEM_PER_CPU SLURM_MEM_PER_GPU SLURM_MEM_PER_NODE | ||
echo "Running on $SLURM_JOB_NODELIST" | ||
python -m apps.grpo.main --config=apps/grpo/multihost.yaml |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,8 @@ | |
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from __future__ import annotations | ||
|
||
|
||
import asyncio | ||
import os | ||
import sys | ||
|
@@ -91,6 +93,7 @@ class EngineConfig(EngineArgs): | |
tensor_parallel_size: int = 1 | ||
pipeline_parallel_size: int = 1 | ||
enforce_eager: bool = False | ||
enable_expert_parallel: bool = False | ||
|
||
# Original method returns False when not run in the main thread | ||
_is_v1_supported_oracle = lambda *_: True | ||
|
@@ -103,7 +106,8 @@ def from_dict(cls, d: Mapping): | |
return cls(**valid_args) | ||
|
||
def create_vllm_config(self) -> VllmConfig: | ||
# This is not a typo: EngineArgs.create_engine_config | ||
"""Converts the current EngineConfig into vLLM's VllmConfig.""" | ||
# Note: EngineArgs.create_engine_config | ||
# creates a VllmConfig | ||
return self.create_engine_config(UsageContext.LLM_CLASS) | ||
|
||
|
@@ -144,9 +148,15 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] | |
# automatically. | ||
worker_procs = await get_proc_mesh(process_config=process_config) | ||
|
||
# TODO - we will want to ensure colocation with workers | ||
# TODO - issues/144 we will want to ensure colocation with workers | ||
# We're currently locating the Policy on the local host proc mesh. | ||
# Initializing vLLM without setting env variables at proc_mesh creation | ||
# level leads to issues. | ||
# Once we can create multiple proc meshes on a host mesh, we can ensure | ||
# host colocation. | ||
policy_proc_config = copy(process_config) | ||
policy_proc_config.num_procs = 1 | ||
policy_proc_config.num_hosts = None | ||
policy_proc_config.with_gpus = False | ||
|
||
policy_proc = await get_proc_mesh(process_config=policy_proc_config) | ||
|
@@ -201,7 +211,7 @@ async def setup(self): | |
await self.policy_worker.setup.call() | ||
|
||
self.request_id = 0 | ||
self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} | ||
self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} | ||
self.vllm_config: VllmConfig = self.engine_config.create_vllm_config() | ||
|
||
# Setup sampling params | ||
|
@@ -462,10 +472,6 @@ async def setup_kv_cache(self): | |
self.worker.initialize_cache(kv_cache_config.num_blocks, 0) | ||
return kv_cache_config | ||
|
||
@endpoint | ||
async def get_vllm_config(self) -> VllmConfig: | ||
return self.vllm_config | ||
|
||
@endpoint | ||
async def _get_model_params(self) -> dict[str, torch.Tensor]: | ||
model = self.worker.model_runner.model | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be awesome!