Initial multi-host support #151
New file (the GRPO multi-host configuration, presumably the `apps/grpo/multihost.yaml` referenced by the Slurm script below), +79 lines:

```yaml
# GRPO Training Configuration

# Currently a fork of the main yaml, this just shows
# placement of trainer and inference servers on separate hosts.

# Global configuration
group_size: 4
batch_size: 4
max_req_tokens: 512
max_res_tokens: 128
model: "Qwen/Qwen3-1.7B-Base"

# Dataset configuration
dataset:
  path: "openai/gsm8k"
  revision: "main"
  data_split: "train"
  streaming: true
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false

# Policy configuration
policy:
  engine_config:
    model: ${model}
    tensor_parallel_size: 1
    pipeline_parallel_size: 1
    enforce_eager: true
  sampling_config:
    n: 4
    max_tokens: 128
    temperature: 1.0
    top_p: 1.0
  service:
    hosts_per_replica: 1  # Places on a remote node
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: true

# Trainer configuration
trainer:
  learning_rate: 1e-5
  service:
    hosts_per_replica: 1  # Places on a remote node
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: true

# Replay buffer configuration
replay_buffer:
  batch_size: ${batch_size}
  max_policy_age: 0
  dp_size: 1
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false

# Compute advantages configuration
compute_advantages:
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false

# Reference model configuration
ref_model:
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: true

# Reward actor configuration
reward_actor:
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false
```
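The `${model}` and `${batch_size}` values above are OmegaConf interpolations; the apps receive their config as an OmegaConf `DictConfig`, as the script diff further down shows. A minimal sketch (not part of the PR) of how they resolve:

```python
from omegaconf import OmegaConf

# Mirrors the shape of the config above: leaf values reference root-level keys.
cfg = OmegaConf.create(
    {
        "model": "Qwen/Qwen3-1.7B-Base",
        "batch_size": 4,
        "policy": {"engine_config": {"model": "${model}"}},
        "replay_buffer": {"batch_size": "${batch_size}"},
    }
)

# Interpolations resolve against the root config on access.
print(cfg.policy.engine_config.model)  # Qwen/Qwen3-1.7B-Base
print(cfg.replay_buffer.batch_size)    # 4
```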
New file (multi-host inference configuration for DeepSeek-R1), +22 lines:

```yaml
# NOTE - this won't work until we have proper HostMesh support
policy:
  engine_config:
    model: "deepseek-ai/DeepSeek-R1-0528"
    tensor_parallel_size: 16
    pipeline_parallel_size: 1
    enable_expert_parallel: true
    # enforce_eager: true
  sampling_config:
    n: 2
    guided_decoding: false
    max_tokens: 512
  available_devices: null
  service:
    procs_per_replica: 8
    hosts_per_replica: 2
    num_replicas: 1
    with_gpus: true

# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
```
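A quick sanity check on the sizing in this config, under two assumptions of mine that the diff does not state (each service proc drives one GPU, and vLLM's world size is `tensor_parallel_size * pipeline_parallel_size`):

```python
# Sizing arithmetic behind the DeepSeek-R1 config above.
tensor_parallel_size = 16
pipeline_parallel_size = 1
procs_per_replica = 8
hosts_per_replica = 2

world_size = tensor_parallel_size * pipeline_parallel_size  # 16 ranks needed
provided_procs = procs_per_replica * hosts_per_replica      # 8 procs/host x 2 hosts = 16
assert provided_procs == world_size, "service placement must cover the vLLM world size"
```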
Changes to the script that spawns the `Policy` service and runs a single generation:

```diff
@@ -18,6 +18,12 @@
 from omegaconf import DictConfig
 from src.forge.data.utils import exclude_service
 from vllm.outputs import RequestOutput
+from forge.controller.provisioner import shutdown
+
+import os
+
+os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
+os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"


 async def run(cfg: DictConfig):
@@ -27,28 +33,29 @@ async def run(cfg: DictConfig):
     prompt = "What is 3+5?" if gd else "Tell me a joke"

     print("Spawning service...")

     policy = await spawn_service(
         ServiceConfig(**cfg.policy.service),
         Policy,
         **exclude_service(cfg.policy),
     )

-    async with policy.session():
-        print("Requesting generation...")
-        response_output: RequestOutput = await policy.generate.choose(prompt=prompt)
-
-        print("\nGeneration Results:")
-        print("=" * 80)
-        for batch, response in enumerate(response_output.outputs):
-            print(f"Sample {batch + 1}:")
-            print(f"User: {prompt}")
-            print(f"Assistant: {response.text}")
-            print("-" * 80)
-
-    print("\nShutting down...")
-
-    await shutdown_service(policy)
+    try:
+        async with policy.session():
+            print("Requesting generation...")
+            response_output: RequestOutput = await policy.generate.choose(prompt=prompt)
+
+            print("\nGeneration Results:")
+            print("=" * 80)
+            for batch, response in enumerate(response_output.outputs):
+                print(f"Sample {batch + 1}:")
+                print(f"User: {prompt}")
+                print(f"Assistant: {response.text}")
+                print("-" * 80)
+
+            print("\nShutting down...")
+    finally:
+        await shutdown_service(policy)
+        await shutdown()


 @parse
```
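The two `HYPERACTOR_*` variables are set at module import time, before any service is spawned. A minimal sketch of that ordering; the assumption that the consuming library reads these variables once at initialization is mine and is not stated in the diff:

```python
import os

# Set before importing anything that may read these variables at init time
# (assumption: they are consulted once, when the runtime initializes).
os.environ.setdefault("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS", "600")
os.environ.setdefault("HYPERACTOR_CODE_MAX_FRAME_LENGTH", str(1 << 30))  # 1073741824 bytes

from forge.controller.provisioner import shutdown  # noqa: E402  (import as in the diff)
```

The try/finally change in the same hunk serves a similar robustness goal: `shutdown_service` and the provisioner `shutdown` now run even if generation raises.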
New file (single-host, 4-GPU inference configuration for Qwen2.5-32B), +20 lines:

```yaml
policy:
  engine_config:
    model: "Qwen/Qwen2.5-32B"
    tensor_parallel_size: 4
    pipeline_parallel_size: 1
    enforce_eager: true
  sampling_config:
    n: 2
    guided_decoding: false
    max_tokens: 512
  available_devices: null
  service:
    procs_per_replica: 4
    num_replicas: 1
    hosts_per_replica: 1
    with_gpus: true

# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
```
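Unlike the DeepSeek-R1 config above, the whole tensor-parallel group here fits on a single node (`procs_per_replica: 4` matches `tensor_parallel_size: 4`), so `hosts_per_replica: 1` is sufficient.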
New file (Slurm launcher for the multi-host GRPO run), +13 lines:

```bash
#!/bin/bash

#SBATCH --job-name=forge
#SBATCH --output=slogs/forge.out
#SBATCH --error=slogs/forge.err
#SBATCH --partition=h100-low   # or h100-high / h100-prod / all
#SBATCH --nodes=1              # 1 node
#SBATCH --ntasks=1             # 1 task (process)
#SBATCH --gres=gpu:8           # request 8 GPUs
#SBATCH --time=01:00:00        # walltime hh:mm:ss

unset SLURM_MEM_PER_CPU SLURM_MEM_PER_GPU SLURM_MEM_PER_NODE
echo "Running on $SLURM_JOB_NODELIST"
python -m apps.grpo.main --config=apps/grpo/multihost.yaml
```
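The script would presumably be submitted with `sbatch`; note that Slurm does not create the `slogs/` directory used by `--output`/`--error`, so it typically needs to exist before submission.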
Changes to the policy implementation (`EngineConfig`, `Policy`, `PolicyWorker`):

```diff
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from __future__ import annotations
+
 import asyncio
 import logging
 import os
@@ -92,6 +94,7 @@ class EngineConfig(EngineArgs):
     tensor_parallel_size: int = 1
     pipeline_parallel_size: int = 1
     enforce_eager: bool = False
+    enable_expert_parallel: bool = False

     @classmethod
     def from_dict(cls, d: Mapping):
@@ -100,6 +103,16 @@ def from_dict(cls, d: Mapping):
         valid_args = {k: v for k, v in d.items() if k in all_fields}
         return cls(**valid_args)

+    @classmethod
+    def as_engine_args(cls, config: Mapping | EngineConfig) -> EngineConfig:
+        if isinstance(config, Mapping):
+            config = EngineConfig.from_dict(config)
+
+        # Original method returns False when not run in the main thread
+        config._is_v1_supported_oracle = lambda *_: True
+        # Build Config
+        return config.create_engine_config(UsageContext.LLM_CLASS)
+

 @dataclass
 class Policy(PolicyInterface):
@@ -138,9 +151,15 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         # automatically.
         worker_procs = await get_proc_mesh(process_config=process_config)

-        # TODO - we will want to ensure colocation with workers
+        # TODO - issues/144 we will want to ensure colocation with workers
+        # We're currently locating the Policy on the local host proc mesh
+        # vLLM initialization without setting env variables at proc_mesh creation
+        # level leads to issues.
+        # Once we can create multiple proc meshes on a host mesh, we can ensure
+        # host colocation
         policy_proc_config = copy(process_config)
         policy_proc_config.num_procs = 1
+        policy_proc_config.num_hosts = None
         policy_proc_config.with_gpus = False

         policy_proc = await get_proc_mesh(process_config=policy_proc_config)
@@ -196,7 +215,7 @@ async def setup(self):

         self.request_id = 0
         self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
-        self.vllm_args = await self.policy_worker.get_vllm_args.choose()
+        self.vllm_args = EngineConfig.as_engine_args(self.engine_config)

         # Setup sampling params
         self.sampling_params = get_default_sampling_params(
@@ -382,13 +401,7 @@ def __post_init__(self):
         - all LLM generate methods, verify against LLM inputs
         - all executor methods verify no changes
         """
-        if isinstance(self.vllm_args, Mapping):
-            self.vllm_args = EngineConfig.from_dict(self.vllm_args)
-
-        # Original method returns False when not run in the main thread
-        self.vllm_args._is_v1_supported_oracle = lambda *_: True
-        # Build Config
-        self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
+        self.vllm_args = EngineConfig.as_engine_args(self.vllm_args)

     @endpoint
     async def setup(self, store: MultiProcessStore = None):
@@ -476,10 +489,6 @@ async def setup_kv_cache(self):
         self.worker.initialize_cache(kv_cache_config.num_blocks, 0)
         return kv_cache_config

-    @endpoint
-    async def get_vllm_args(self):
-        return self.vllm_args
-
     @endpoint
     async def _get_model_params(self) -> Dict[str, torch.Tensor]:
         model = self.worker.model_runner.model
```
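The recurring change in this file is factoring the mapping-to-engine-args conversion out of `PolicyWorker.__post_init__` (and the removed `get_vllm_args` endpoint) into the shared `EngineConfig.as_engine_args` classmethod, so `Policy.setup` can build the args locally instead of asking a worker. A generic sketch of that normalization pattern, using illustrative names rather than the repo's actual classes:

```python
from collections.abc import Mapping
from dataclasses import dataclass, fields
from typing import Union


@dataclass
class ExampleConfig:
    model: str = "facebook/opt-125m"
    tensor_parallel_size: int = 1

    @classmethod
    def from_dict(cls, d: Mapping) -> "ExampleConfig":
        # Keep only keys that are actual dataclass fields, mirroring the
        # filtering that EngineConfig.from_dict does in the diff above.
        valid = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in valid})

    @classmethod
    def normalize(cls, config: Union[Mapping, "ExampleConfig"]) -> "ExampleConfig":
        # Accept either a raw mapping or an already-built config, so every
        # caller goes through one entry point instead of duplicating the check.
        if isinstance(config, Mapping):
            config = cls.from_dict(config)
        return config


cfg = ExampleConfig.normalize(
    {"model": "Qwen/Qwen2.5-32B", "tensor_parallel_size": 4, "ignored_key": True}
)
print(cfg)  # ExampleConfig(model='Qwen/Qwen2.5-32B', tensor_parallel_size=4)
```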
Review comment: This would be awesome!