Skip to content

Commit ce91430

Browse files
allenwang28Allen Wang
andauthored
Enables dynamic GPU allocation for local workloads (#91)
* initial commit * add back spawn * stash * park * add gpu resource management * add gpu resource management * update test apis * stash * sft v2 works again * renames, adds stop capability * some updates * typo fix * fix test * missing import * no nested submodule * check * num_gpus => with_gpus * proc mesh update --------- Co-authored-by: Allen Wang <[email protected]>
1 parent ddd0794 commit ce91430

File tree

19 files changed

+454
-58
lines changed

19 files changed

+454
-58
lines changed

apps/grpo/main.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from datasets import load_dataset
1414
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
1515
from forge.actors.replay_buffer import ReplayBuffer
16-
from forge.controller import ServiceConfig, spawn_service
1716
from forge.controller.actor import ForgeActor
17+
from forge.controller.service import ServiceConfig, spawn_service
1818
from forge.data.rewards import MathReward, ThinkingReward
1919
from forge.util.metric_logging import get_metric_logger
2020
from monarch.actor import endpoint
@@ -351,40 +351,33 @@ async def main():
351351
)
352352

353353
# ---- Setup services ---- #
354-
default_service_cfg = ServiceConfig(
355-
procs_per_replica=1,
356-
num_replicas=1,
357-
)
358-
359354
policy = await spawn_service(
360-
default_service_cfg,
355+
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
361356
Policy,
362357
PolicyConfig(
363358
num_workers=1,
364359
worker_params=WorkerConfig(model=model),
365360
sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
366-
available_devices="3",
367361
),
368362
)
369363

370364
trainer = await spawn_service(
371-
default_service_cfg,
365+
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
372366
Trainer,
373367
learning_rate=1e-5,
374368
beta=0.1,
375369
model_name=model,
376-
device=torch.device("cuda:1"),
377370
)
378371

379372
replay_buffer = await spawn_service(
380-
default_service_cfg,
373+
ServiceConfig(procs_per_replica=1, num_replicas=1),
381374
ReplayBuffer,
382375
batch_size=4,
383376
max_policy_age=1,
384377
)
385378

386379
dataloader = await spawn_service(
387-
default_service_cfg,
380+
ServiceConfig(procs_per_replica=1, num_replicas=1),
388381
DatasetActor,
389382
"openai/gsm8k",
390383
"main",
@@ -393,21 +386,20 @@ async def main():
393386
)
394387

395388
compute_advantages = await spawn_service(
396-
default_service_cfg,
389+
ServiceConfig(procs_per_replica=1, num_replicas=1),
397390
ComputeAdvantages,
398391
gamma=0.99,
399392
lambda_=0.95,
400393
)
401394

402395
ref_model = await spawn_service(
403-
default_service_cfg,
396+
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
404397
RefModel,
405398
model_name=model,
406-
device=torch.device("cuda:2"),
407399
)
408400

409401
reward_actor = await spawn_service(
410-
default_service_cfg,
402+
ServiceConfig(procs_per_replica=1, num_replicas=1),
411403
RewardActor,
412404
reward_functions=[MathReward(), ThinkingReward()],
413405
)

apps/rl/llama3_8b.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ trainer:
1818
processes:
1919
scheduler: local # local | mast (not supported yet)
2020
num_hosts: 1
21+
num_gpus: 4
2122
num_procs: 4
2223

2324
optimizer:
@@ -65,6 +66,7 @@ replay_buffer:
6566
processes:
6667
scheduler: local # local | mast (not supported yet)
6768
num_hosts: 1
69+
num_gpus: 0
6870
num_procs: 1
6971

7072
# policy:

apps/sft_v2/llama3_8b.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ comm:
1414
model:
1515
name: llama3
1616
flavor: 8B
17-
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct
17+
tokenizer_path: /tmp/Llama-3.1-8B-Instruct
1818

1919
processes:
2020
scheduler: local # local | mast (not supported yet)
2121
num_hosts: 1
2222
num_procs: 8
23+
num_gpus: 8
2324

2425
optimizer:
2526
name: AdamW

apps/sft_v2/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
"""To run:
88
9-
python -m apps.sft.main --config apps/sft/llama3_8b.yaml
9+
python -m apps.sft_v2.main --config apps/sft_v2/llama3_8b.yaml
1010
1111
"""
1212

apps/vllm/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
from typing import List
1717

1818
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
19-
from forge.controller.service import ServiceConfig
20-
from forge.controller.spawn import spawn_service
19+
from forge.controller.service import ServiceConfig, spawn_service
2120
from vllm.outputs import CompletionOutput
2221

2322

src/forge/controller/__init__.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,11 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
from .actor import ForgeActor
7-
from .interface import ServiceInterface, Session, SessionContext
8-
from .proc_mesh import get_proc_mesh, spawn_actors
9-
from .service import Service, ServiceConfig
10-
from .spawn import spawn_service
7+
from .proc_mesh import get_proc_mesh, spawn_actors, stop_proc_mesh
118

129
__all__ = [
13-
"Service",
14-
"ServiceConfig",
15-
"ServiceInterface",
16-
"Session",
17-
"SessionContext",
18-
"spawn_service",
1910
"spawn_actors",
11+
"stop_proc_mesh",
2012
"get_proc_mesh",
2113
"ForgeActor",
2214
]

src/forge/controller/proc_mesh.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@
1111

1212
import os
1313
import socket
14+
from functools import partial
1415

1516
from monarch.actor import proc_mesh, ProcMesh
1617
from monarch.tools import commands
1718
from monarch.tools.config import Config
1819
from omegaconf import DictConfig
1920

2021
from forge.controller import ForgeActor
22+
23+
from forge.controller.system_controllers.gpu_manager import get_gpu_ids, release_gpus
2124
from forge.types import ProcessConfig
2225

2326
logger: logging.Logger = logging.getLogger(__name__)
@@ -48,27 +51,52 @@ async def spawn_actors(
4851
set_address: bool = False,
4952
):
5053
"""Setup process Mesh and spawn Actors."""
51-
mesh = await get_proc_mesh(processes, set_address)
54+
mesh = await get_proc_mesh(processes)
5255
actors = await mesh.spawn(name, actor_cls, **cfg)
5356
actors.mesh = mesh
5457
return actors
5558

5659

57-
async def get_proc_mesh(process_config: ProcessConfig, set_address=False) -> ProcMesh:
58-
env = None
59-
if set_address:
60-
env = {
61-
"MASTER_ADDR": str(socket.gethostname()),
62-
"MASTER_PORT": str(_find_free_port()),
63-
}
60+
async def get_proc_mesh(process_config: ProcessConfig) -> ProcMesh:
61+
"""Returns a proc mesh with the given process config."""
62+
# TODO - modify this to work with multi-host
63+
env = {
64+
"MASTER_ADDR": str(socket.gethostname()),
65+
"MASTER_PORT": str(_find_free_port()),
66+
}
67+
gpu_ids = None
68+
69+
def _setup_env(env: dict[str, str]):
70+
"""Sets up the environment on proc mesh creation."""
71+
for k, v in env.items():
72+
os.environ[k] = v
73+
6474
if process_config.scheduler == "local":
6575
if process_config.num_hosts != 1:
6676
raise ValueError("Local scheduler only supports 1 host")
67-
return await proc_mesh(gpus=process_config.num_procs, env=env)
77+
78+
if process_config.with_gpus:
79+
gpu_ids = await get_gpu_ids(process_config.num_procs)
80+
env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
81+
82+
# TODO - update to use this_host() whenever it supports
83+
# being run within actors:
84+
# AttributeError: NYI: attempting to get ProcMesh attribute `slice` on object that's
85+
# actually a ProcMeshRef
86+
# return this_host().spawn_procs(
87+
# per_host={"procs": process_config.num_procs},
88+
# bootstrap=partial(_setup_env, env=env),
89+
# )
90+
m = proc_mesh(gpus=process_config.num_procs, env=env)
91+
m._gpu_ids = gpu_ids
92+
return m
6893
elif process_config.scheduler == "mast":
6994
if not MAST_SUPPORTED:
7095
raise ValueError("MAST is not supported on this platform")
7196

97+
if process_config.with_gpus:
98+
raise ValueError("NYI - need to add HostMesh tracking in GpuManager")
99+
72100
logging.info("Scheduling on MAST with: ", process_config)
73101
jobname = f"monarch-{getpass.getuser()}"
74102
config = Config(
@@ -104,12 +132,7 @@ async def get_proc_mesh(process_config: ProcessConfig, set_address=False) -> Pro
104132
)
105133
alloc = await allocator.allocate(AllocSpec(constraints, **mesh_dimensions))
106134
if env:
107-
108-
def setup(): # noqa: FB811
109-
for k, v in env.items():
110-
os.environ[k] = v
111-
112-
p = await ProcMesh.from_alloc(alloc, setup=setup)
135+
p = await ProcMesh.from_alloc(alloc, setup=partial(_setup_env, env=env))
113136
else:
114137
p = await ProcMesh.from_alloc(alloc)
115138
await p.logging_option(stream_to_client=True, aggregate_window_sec=3)
@@ -118,6 +141,15 @@ def setup(): # noqa: FB811
118141
raise ValueError("Unsupported scheduler: {}".format(process_config.scheduler))
119142

120143

144+
async def stop_proc_mesh(mesh: ProcMesh) -> None:
145+
"""Stops the given proc mesh."""
146+
if hasattr(mesh, "_gpu_ids") and mesh._gpu_ids is not None:
147+
gpu_ids = mesh._gpu_ids
148+
logger.debug("Releasing GPUs: %s", gpu_ids)
149+
await release_gpus(gpu_ids)
150+
await mesh.stop()
151+
152+
121153
def _find_free_port() -> int:
122154
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
123155
s.bind(("localhost", 0))
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .interface import ServiceInterface, Session, SessionContext
8+
from .metrics import ServiceMetrics
9+
from .replica import Replica, ReplicaMetrics
10+
from .service import Service, ServiceConfig
11+
from .spawn import spawn_service
12+
13+
__all__ = [
14+
"Replica",
15+
"ReplicaMetrics",
16+
"Service",
17+
"ServiceConfig",
18+
"ServiceInterface",
19+
"ServiceMetrics",
20+
"Session",
21+
"SessionContext",
22+
"spawn_service",
23+
]
File renamed without changes.

src/forge/controller/metrics.py renamed to src/forge/controller/service/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dataclasses import dataclass, field
1515
from typing import Dict, List
1616

17-
from forge.controller.replica import ReplicaMetrics
17+
from forge.controller.service.replica import ReplicaMetrics
1818

1919

2020
# TODO - tie this into metrics logger when it exists.

0 commit comments

Comments
 (0)