29 changes: 19 additions & 10 deletions src/forge/actors/generator.py
@@ -48,7 +48,12 @@
load_tensor_from_dcp,
)

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.controller import (
ForgeActor,
get_proc_mesh,
host_mesh_from_proc,
stop_proc_mesh,
)
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt
from forge.env import TORCHSTORE_USE_RDMA
@@ -139,17 +144,22 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
mesh_name=cls.mesh_name,
)

# TODO - issues/144 we will want to ensure colocation with workers
# We're currently locating the Generator on the local host proc mesh
# vLLM initialization without setting env variables at proc_mesh creation
# level leads to issues. Once we can create multiple proc meshes on a host mesh,
# we can ensure host colocation
# First, spawn the worker processes which may or may not be
# on remote hosts.
worker_procs = await get_proc_mesh(process_config=process_config)

# Then, grab a single host from the workers...
host_mesh = await host_mesh_from_proc(worker_procs)
singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
host_mesh = host_mesh.slice(**singleton_slice)

# We ask the provisioner for a single process on a single host
generator_proc_config = copy(process_config)
generator_proc_config.procs = 1
generator_proc_config.hosts = None
generator_proc_config.with_gpus = False
generator_proc = await get_proc_mesh(process_config=generator_proc_config)

generator_proc = await get_proc_mesh(
process_config=generator_proc_config, host_mesh=host_mesh
)
# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
generator = generator_proc.spawn(
@@ -159,7 +169,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
**kwargs,
)

worker_procs = await get_proc_mesh(process_config=process_config)
vllm_config = (
await generator.get_vllm_config.call_one()
) # Config should be the same across all actors
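Taken together, the generator.py changes reverse the launch order: the vLLM worker processes are spawned first, and the Generator actor is then placed on a single host drawn from the workers' host mesh, rather than on the local host proc mesh. A minimal sketch of the resulting pattern, assuming the `forge.controller` APIs used in this diff (`get_proc_mesh`, `host_mesh_from_proc`) and a `process_config` like the one passed to `launch`; this is an illustration, not the full `launch` implementation:

```python
from copy import copy

from forge.controller import get_proc_mesh, host_mesh_from_proc


async def place_generator_with_workers(process_config):
    # Spawn the (possibly remote, GPU-backed) worker processes first.
    worker_procs = await get_proc_mesh(process_config=process_config)

    # Recover the host mesh behind those workers and narrow it to a single host.
    host_mesh = await host_mesh_from_proc(worker_procs)
    host_mesh = host_mesh.slice(**{dim: slice(0, 1) for dim in host_mesh.extent.keys()})

    # Ask the provisioner for one CPU-only process on that host
    # and use it for the Generator actor itself.
    generator_proc_config = copy(process_config)
    generator_proc_config.procs = 1
    generator_proc_config.hosts = None
    generator_proc_config.with_gpus = False
    generator_proc = await get_proc_mesh(
        process_config=generator_proc_config, host_mesh=host_mesh
    )
    return worker_procs, generator_proc
```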
15 changes: 1 addition & 14 deletions src/forge/actors/trainer.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import logging
import math
import os
import shutil

@@ -18,7 +17,7 @@
import torch.distributed.checkpoint as dcp
import torchstore as ts

from monarch.actor import current_rank, current_size, endpoint
from monarch.actor import endpoint
from torch import Tensor
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torchtitan.config.job_config import (
@@ -163,19 +162,7 @@ def __post_init__(self):
self.step = 1 # fragile contract.
self.num_training_steps = self.training.steps
self.gradient_accumulation_steps = 1
self.rank = current_rank().rank
self.size = math.prod(current_size().values())

env = {
"RANK": str(self.rank),
"LOCAL_RANK": str(self.rank),
"LOCAL_WORLD_SIZE": str(self.size),
"GROUP_RANK": str(self.size),
"GROUP_WORLD_SIZE": str(self.size),
"ROLE_RANK": str(self.rank),
"ROLE_WORLD_SIZE": str(self.size),
"ROLE_NAME": "rank",
"WORLD_SIZE": str(self.size),
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}
os.environ.update(env)
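The trainer.py removal is the counterpart of the provisioner.py change below: RANK, WORLD_SIZE, and the related group/role variables are now expected to be set centrally when the proc mesh is provisioned, so `__post_init__` no longer derives them from `current_rank()`/`current_size()`. Roughly what remains of the env handling after this change (a sketch, not the full `__post_init__`):

```python
import os

# Only the allocator knob is still exported locally; the torch.distributed
# variables (RANK, WORLD_SIZE, MASTER_ADDR, ...) are handled by the provisioner.
env = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}
os.environ.update(env)
```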
10 changes: 9 additions & 1 deletion src/forge/controller/provisioner.py
@@ -20,8 +20,9 @@

from monarch.tools import commands

from forge.controller.launcher import BaseLauncher, get_launcher
from monarch.utils import setup_env_for_distributed

from forge.controller.launcher import BaseLauncher, get_launcher
from forge.env import all_env_vars, FORGE_DISABLE_METRICS
from forge.types import ProcessConfig, ProvisionerConfig

@@ -283,6 +284,13 @@ def bootstrap(env: dict[str, str]):
bootstrap=functools.partial(bootstrap, env=env_vars),
)

# Set up environment variables for PyTorch distributed...
await setup_env_for_distributed(
procs,
master_addr=addr,
master_port=port,
)

if is_remote:
await self.launcher.remote_setup(procs)

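On the provisioner side, the distributed environment is now configured right after the processes are created, using monarch's `setup_env_for_distributed` helper with the rendezvous address and port the provisioner already holds. A sketch of the call site, assuming `procs`, `addr`, and `port` carry the same meaning as in the surrounding provisioner code:

```python
from monarch.utils import setup_env_for_distributed

# procs: the freshly created proc mesh; addr/port: the master address and port
# used for torch.distributed rendezvous (names follow the diff).
await setup_env_for_distributed(
    procs,
    master_addr=addr,
    master_port=port,
)
```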