Skip to content
2 changes: 1 addition & 1 deletion apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ async def main():
# ---- Setup services ---- #
default_service_cfg = ServiceConfig(
procs_per_replica=1,
gpus_per_replica=1,
num_replicas=1,
)

Expand All @@ -363,7 +364,6 @@ async def main():
num_workers=1,
worker_params=WorkerConfig(model=model),
sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
available_devices="3",
),
)

Expand Down
2 changes: 2 additions & 0 deletions apps/rl/llama3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ trainer:
processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_gpus: 4
num_procs: 4

optimizer:
Expand Down Expand Up @@ -65,6 +66,7 @@ replay_buffer:
processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_gpus: 0
num_procs: 1

# policy:
Expand Down
3 changes: 2 additions & 1 deletion apps/sft_v2/llama3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ comm:
model:
name: llama3
flavor: 8B
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct
tokenizer_path: /tmp/Llama-3.1-8B-Instruct

processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_procs: 8
num_gpus: 8

optimizer:
name: AdamW
Expand Down
2 changes: 1 addition & 1 deletion apps/sft_v2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml
python -m apps.sft_v2.main --config apps/sft_v2/llama3_8b.yaml

"""

Expand Down
15 changes: 12 additions & 3 deletions src/forge/controller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .actor import ForgeActor
from .interface import ServiceInterface, Session, SessionContext
from .proc_mesh import get_proc_mesh, spawn_actors
from .service import Service, ServiceConfig
from .spawn import spawn_service
from .service import (
Replica,
ReplicaMetrics,
Service,
ServiceConfig,
ServiceInterface,
Session,
SessionContext,
spawn_service,
)

__all__ = [
"Service",
Expand All @@ -19,4 +26,6 @@
"spawn_actors",
"get_proc_mesh",
"ForgeActor",
"Replica",
"ReplicaMetrics",
]
12 changes: 12 additions & 0 deletions src/forge/controller/custom_actors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .gpu_manager import get_gpu_ids, release_gpus

__all__ = [
"get_gpu_ids",
"release_gpus",
]
75 changes: 75 additions & 0 deletions src/forge/controller/custom_actors/gpu_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Implements an actor responsible for tracking and assigning GPU devices on HostMesh."""

import logging

from monarch.actor import ActorError, endpoint, get_or_spawn_controller

from forge.controller import ForgeActor

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class GpuManager(ForgeActor):
    """An actor that tracks and assigns GPU devices on given HostMeshes.

    The manager keeps a pool of free device indices. Callers take devices out
    of the pool with ``get_gpus`` and hand them back with ``release_gpus``.
    Device ids cross the endpoint boundary as strings and are stored
    internally as ints.
    """

    def __init__(self):
        # TODO - extend this to support multiple HostMeshes too
        # The pool is currently hard-coded to device indices 0..7.
        self.available_gpus = set(range(0, 8))

    @endpoint
    def get_available_gpus(self) -> list[str]:
        """Returns the currently unassigned GPU ids, in ascending order."""
        return [str(gpu) for gpu in sorted(self.available_gpus)]

    @endpoint
    def get_gpus(self, num_gpus: int) -> list[str]:
        """Assigns GPU devices.

        Args:
            num_gpus: How many devices to take out of the free pool.

        Returns:
            The ids of the assigned devices, lowest indices first.

        Raises:
            ValueError: If ``num_gpus`` is negative.
            RuntimeError: If fewer than ``num_gpus`` devices are free.
        """
        # A negative count would otherwise turn the slice below into
        # "everything except the last |n| devices" and silently over-allocate.
        if num_gpus < 0:
            raise ValueError(f"num_gpus must be non-negative, got {num_gpus}")
        if num_gpus > len(self.available_gpus):
            raise RuntimeError("Not enough GPUs available")
        # Sort so that allocation does not depend on set iteration order.
        gpus = sorted(self.available_gpus)[:num_gpus]
        self.available_gpus.difference_update(gpus)
        return [str(gpu) for gpu in gpus]

    @endpoint
    def release_gpus(self, gpu_ids: list[str]) -> None:
        """Releases the given GPU devices back into the free pool.

        Args:
            gpu_ids: Device ids, as strings, to mark as available again.

        Raises:
            ValueError: If any id cannot be parsed as an integer. In that
                case the pool is left untouched.
        """
        # Parse everything first so a malformed id cannot leave the pool
        # partially updated.
        released = [int(gpu_id) for gpu_id in gpu_ids]
        self.available_gpus.update(released)

    def __repr__(self) -> str:
        return "GpuManager"


async def get_gpu_manager() -> GpuManager:
    """Fetch the GPU manager controller registered as "gpu_manager".

    The lookup goes through ``get_or_spawn_controller``. If it fails with an
    ``ActorError``, the exception carried by that error is raised instead,
    chained to the original ``ActorError``.
    """
    try:
        manager = await get_or_spawn_controller("gpu_manager", GpuManager)
    except ActorError as err:
        # Surface the wrapped error rather than the Monarch wrapper.
        raise err.exception from err
    return manager


async def get_gpu_ids(num_gpus: int) -> list[str]:
    """Gets GPU IDs for the given number of GPUs.

    Args:
        num_gpus: Number of devices to request from the GPU manager.

    Returns:
        The ids of the assigned devices, as strings.

    Raises:
        Exception: Whatever the GPU manager raised. The underlying error is
            re-raised in place of the ``ActorError`` that wraps it.
    """
    try:
        # Reuse the shared lookup helper instead of repeating the
        # controller name and class here.
        gpu_manager = await get_gpu_manager()
        return await gpu_manager.get_gpus.call_one(num_gpus)
    except ActorError as e:
        # Raise the underlying error instead of the Monarch error
        raise e.exception from e


async def release_gpus(gpu_ids: list[str]) -> None:
    """Releases the given GPU IDs.

    Args:
        gpu_ids: Device ids, as strings, to hand back to the GPU manager.

    Raises:
        Exception: Whatever the GPU manager raised. The underlying error is
            re-raised in place of the ``ActorError`` that wraps it.
    """
    try:
        # Reuse the shared lookup helper instead of repeating the
        # controller name and class here.
        gpu_manager = await get_gpu_manager()
        await gpu_manager.release_gpus.call_one(gpu_ids)
    except ActorError as e:
        # Raise the underlying error instead of the Monarch error
        raise e.exception from e
20 changes: 20 additions & 0 deletions src/forge/controller/custom_actors/service_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Implements an actor that tracks all services running in the workload."""

from monarch.actor import endpoint

from forge.controller import ForgeActor


class ServiceRegistry(ForgeActor):
    """Placeholder actor; neither method does anything yet."""

    def __init__(self):
        """No state is set up."""

    @endpoint
    def register(self):
        """No-op endpoint; takes no arguments and returns None."""
51 changes: 36 additions & 15 deletions src/forge/controller/proc_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@

import os
import socket
from functools import partial

from monarch.actor import proc_mesh, ProcMesh
from monarch.tools import commands
from monarch.tools.config import Config
from omegaconf import DictConfig

from forge.controller import ForgeActor

from forge.controller.custom_actors.gpu_manager import get_gpu_ids
from forge.types import ProcessConfig

logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -48,27 +51,50 @@ async def spawn_actors(
set_address: bool = False,
):
"""Setup process Mesh and spawn Actors."""
mesh = await get_proc_mesh(processes, set_address)
mesh = await get_proc_mesh(processes)
actors = await mesh.spawn(name, actor_cls, **cfg)
actors.mesh = mesh
return actors


async def get_proc_mesh(process_config: ProcessConfig, set_address=False) -> ProcMesh:
env = None
if set_address:
env = {
"MASTER_ADDR": str(socket.gethostname()),
"MASTER_PORT": str(_find_free_port()),
}
async def get_proc_mesh(process_config: ProcessConfig) -> ProcMesh:
"""Returns a proc mesh with the given process config."""

# TODO - modify this to work with multi-host
env = {
"MASTER_ADDR": str(socket.gethostname()),
"MASTER_PORT": str(_find_free_port()),
}

def _setup_env(env: dict[str, str]):
"""Sets up the environment on proc mesh creation."""
for k, v in env.items():
os.environ[k] = v

if process_config.scheduler == "local":
if process_config.num_hosts != 1:
raise ValueError("Local scheduler only supports 1 host")
return await proc_mesh(gpus=process_config.num_procs, env=env)

if process_config.num_gpus > 0:
gpu_ids = await get_gpu_ids(process_config.num_gpus)
env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))

# TODO - update to use this_host() whenever it supports
# being run within actors:
# AttributeError: NYI: attempting to get ProcMesh attribute `slice` on object that's
# actually a ProcMeshRef
# return this_host().spawn_procs(
# per_host={"procs": process_config.num_procs},
# bootstrap=partial(_setup_env, env=env),
# )
return proc_mesh(gpus=process_config.num_procs, env=env)
elif process_config.scheduler == "mast":
if not MAST_SUPPORTED:
raise ValueError("MAST is not supported on this platform")

if process_config.num_gpus != 0:
raise ValueError("NYI - need to add HostMesh tracking in GpuManager")

logging.info("Scheduling on MAST with: ", process_config)
jobname = f"monarch-{getpass.getuser()}"
config = Config(
Expand Down Expand Up @@ -104,12 +130,7 @@ async def get_proc_mesh(process_config: ProcessConfig, set_address=False) -> Pro
)
alloc = await allocator.allocate(AllocSpec(constraints, **mesh_dimensions))
if env:

def setup(): # noqa: FB811
for k, v in env.items():
os.environ[k] = v

p = await ProcMesh.from_alloc(alloc, setup=setup)
p = await ProcMesh.from_alloc(alloc, setup=partial(_setup_env, env=env))
else:
p = await ProcMesh.from_alloc(alloc)
await p.logging_option(stream_to_client=True, aggregate_window_sec=3)
Expand Down
23 changes: 23 additions & 0 deletions src/forge/controller/service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .interface import ServiceInterface, Session, SessionContext
from .metrics import ServiceMetrics
from .replica import Replica, ReplicaMetrics
from .service import Service, ServiceConfig
from .spawn import spawn_service

__all__ = [
"Service",
"ServiceConfig",
"spawn_service",
"ServiceInterface",
"Session",
"SessionContext",
"ServiceMetrics",
"Replica",
"ReplicaMetrics",
]
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dataclasses import dataclass, field
from typing import Dict, List

from forge.controller.replica import ReplicaMetrics
from forge.controller.service.replica import ReplicaMetrics


# TODO - tie this into metrics logger when it exists.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@

from monarch.actor import Actor, endpoint

from forge.controller.interface import _session_context, Session
from forge.controller.metrics import ServiceMetrics
from forge.controller.replica import Replica, ServiceRequest
from forge.controller.service.interface import _session_context, Session

from forge.controller.service.metrics import ServiceMetrics
from forge.controller.service.replica import Replica, ServiceRequest
from forge.types import ServiceConfig

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from monarch.actor import Actor, proc_mesh

from forge.controller import Service, ServiceConfig
from forge.controller.interface import ServiceInterface
from forge.controller.service import Service, ServiceConfig

from forge.controller.service.interface import ServiceInterface

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
Expand Down
3 changes: 3 additions & 0 deletions src/forge/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class ProcessConfig:

scheduler: Literal["mast", "local"] = "local"
num_procs: int = 1
num_gpus: int = 0
num_hosts: int = 1
# The following is mast specific.
oncall: str = "torchtune"
Expand All @@ -105,6 +106,7 @@ class ServiceConfig:
"""A service config."""

procs_per_replica: int
gpus_per_replica: int
num_replicas: int
num_hosts: int = 1
scheduler: Literal["mast", "local"] = "local"
Expand All @@ -125,6 +127,7 @@ def to_process_config(self) -> ProcessConfig:
return ProcessConfig(
scheduler=self.scheduler,
num_procs=self.procs_per_replica,
num_gpus=self.gpus_per_replica,
num_hosts=self.num_hosts,
oncall=self.oncall,
identity=self.identity,
Expand Down
Loading
Loading