
Commit aa59857

Authored by: allenwang28, init27, Jack-Khuu, casteryh, DNXie
Use monarch's distributed setup utility and colocate vLLM with its workers (#409)
Co-authored-by: Sanyam Bhutani <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
Co-authored-by: casteryh <[email protected]>
Co-authored-by: Danning XIE <[email protected]>
Co-authored-by: Jiyue Wang <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Philip Bontrager <[email protected]>
1 parent 633b219 commit aa59857

File tree

3 files changed (+31, -29 lines)


src/forge/actors/generator.py

Lines changed: 19 additions & 10 deletions
```diff
@@ -48,7 +48,12 @@
     load_tensor_from_dcp,
 )
 
-from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+from forge.controller import (
+    ForgeActor,
+    get_proc_mesh,
+    host_mesh_from_proc,
+    stop_proc_mesh,
+)
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
 from forge.env import TORCHSTORE_USE_RDMA
@@ -139,17 +144,22 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
             mesh_name=cls.mesh_name,
         )
 
-        # TODO - issues/144 we will want to ensure colocation with workers
-        # We're currently locating the Generator on the local host proc mesh
-        # vLLM initialization without setting env variables at proc_mesh creation
-        # level leads to issues. Once we can create multiple proc meshes on a host mesh,
-        # we can ensure host colocation
+        # First, spawn the worker processes which may or may not be
+        # on remote hosts.
+        worker_procs = await get_proc_mesh(process_config=process_config)
+
+        # Then, grab a single host from the workers...
+        host_mesh = await host_mesh_from_proc(worker_procs)
+        singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
+        host_mesh = host_mesh.slice(**singleton_slice)
+
+        # We ask the provisioner for a single process on a single host
         generator_proc_config = copy(process_config)
         generator_proc_config.procs = 1
-        generator_proc_config.hosts = None
         generator_proc_config.with_gpus = False
-        generator_proc = await get_proc_mesh(process_config=generator_proc_config)
-
+        generator_proc = await get_proc_mesh(
+            process_config=generator_proc_config, host_mesh=host_mesh
+        )
         # TODO - expand support so name can stick within kwargs
         actor_name = kwargs.pop("name", cls.__name__)
         generator = generator_proc.spawn(
@@ -159,7 +169,6 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
             **kwargs,
        )
 
-        worker_procs = await get_proc_mesh(process_config=process_config)
         vllm_config = (
             await generator.get_vllm_config.call_one()
         )  # Config should be the same across all actors
```
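Read together, the new launch sequence is: spawn the (possibly remote) worker procs first, recover the host mesh behind them, slice it down to a single host, and then request one CPU-only process on that host so the Generator is guaranteed to share a machine with its vLLM workers. Below is a minimal sketch of that flow; it assumes only the `get_proc_mesh`, `host_mesh_from_proc`, and `ProcessConfig` usage visible in this diff, and the surrounding function is purely illustrative:

```python
from copy import copy

from forge.controller import get_proc_mesh, host_mesh_from_proc


async def colocated_controller_proc(process_config):
    """Illustrative only: spawn workers, then pin a 1-proc CPU mesh on one of their hosts."""
    # 1. Spawn the worker processes; these may land on remote, GPU-backed hosts.
    worker_procs = await get_proc_mesh(process_config=process_config)

    # 2. Recover the host mesh the workers run on, and slice every mesh
    #    dimension to [0, 1) so exactly one host remains.
    host_mesh = await host_mesh_from_proc(worker_procs)
    one_host = host_mesh.slice(**{dim: slice(0, 1) for dim in host_mesh.extent.keys()})

    # 3. Ask the provisioner for a single CPU-only process on that host,
    #    colocating the controller-side actor with its workers.
    controller_config = copy(process_config)
    controller_config.procs = 1
    controller_config.with_gpus = False
    controller_proc = await get_proc_mesh(
        process_config=controller_config, host_mesh=one_host
    )
    return controller_proc, worker_procs
```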

src/forge/actors/trainer.py

Lines changed: 2 additions & 18 deletions
```diff
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-import math
 import os
 import shutil
 
@@ -18,7 +17,7 @@
 import torch.distributed.checkpoint as dcp
 import torchstore as ts
 
-from monarch.actor import current_rank, current_size, endpoint
+from monarch.actor import endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
 from torchtitan.config.job_config import (
@@ -163,22 +162,7 @@ def __post_init__(self):
         self.step = 1  # fragile contract.
         self.num_training_steps = self.training.steps
         self.gradient_accumulation_steps = 1
-        self.rank = current_rank().rank
-        self.size = math.prod(current_size().values())
-
-        env = {
-            "RANK": str(self.rank),
-            "LOCAL_RANK": str(self.rank),
-            "LOCAL_WORLD_SIZE": str(self.size),
-            "GROUP_RANK": str(self.size),
-            "GROUP_WORLD_SIZE": str(self.size),
-            "ROLE_RANK": str(self.rank),
-            "ROLE_WORLD_SIZE": str(self.size),
-            "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self.size),
-            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
-        }
-        os.environ.update(env)
+        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
         logger.info("Compiling loss")
         self.loss = torch.compile(self.loss)
 
```
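With the hand-rolled `RANK`/`WORLD_SIZE` block removed, the trainer relies on those variables already being present when the process group is created (the provisioner now sets them via monarch's utility). A small, hedged sanity check of that contract, using only the standard `torch.distributed` env:// rendezvous variables rather than code from this repo, could look like:

```python
import os

import torch.distributed as dist

# Variables the env:// rendezvous reads; the provisioner is expected to have
# populated them before this process starts.
_REQUIRED = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")


def init_process_group_from_env() -> None:
    missing = [name for name in _REQUIRED if name not in os.environ]
    if missing:
        raise RuntimeError(f"distributed env vars missing: {missing}")
    # env:// reads the rendezvous endpoint and rank/world size from the environment.
    dist.init_process_group(backend="nccl", init_method="env://")
```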

src/forge/controller/provisioner.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -20,8 +20,9 @@
 
 from monarch.tools import commands
 
-from forge.controller.launcher import BaseLauncher, get_launcher
+from monarch.utils import setup_env_for_distributed
 
+from forge.controller.launcher import BaseLauncher, get_launcher
 from forge.env import all_env_vars, FORGE_DISABLE_METRICS
 from forge.types import ProcessConfig, ProvisionerConfig
 
@@ -283,6 +284,14 @@ def bootstrap(env: dict[str, str]):
             bootstrap=functools.partial(bootstrap, env=env_vars),
         )
 
+        if with_gpus:
+            # Set up environment variables for PyTorch distributed...
+            await setup_env_for_distributed(
+                procs,
+                master_addr=addr,
+                master_port=port,
+            )
+
         if is_remote:
             await self.launcher.remote_setup(procs)
 
```
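The `addr` and `port` passed to `setup_env_for_distributed` are produced elsewhere in the provisioner; a common, purely illustrative way to obtain such a rendezvous endpoint (plain standard library, not necessarily what this repo does) is to resolve the head host's address and bind an ephemeral port:

```python
import socket


def pick_master_endpoint() -> tuple[str, int]:
    # Hypothetical helper: resolve an address for this host and reserve an
    # ephemeral port, yielding the kind of (addr, port) pair handed to
    # setup_env_for_distributed above.
    addr = socket.gethostbyname(socket.gethostname())
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    return addr, port
```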
