
Commit 8e2c06b

Author: Allen Wang (committed)
gpu manager uses devices
1 parent 3e0ab2e commit 8e2c06b

File tree

2 files changed: +112 additions, −19 deletions


src/forge/controller/provisioner.py

Lines changed: 64 additions & 14 deletions
@@ -12,6 +12,8 @@
 import socket
 import uuid
 
+import torch
+
 from monarch._src.actor.actor_mesh import ActorMesh
 from monarch._src.actor.shape import Extent
 
@@ -40,8 +42,14 @@ class _RemoteInfoFetcher(Actor):
     """An actor responsible for getting remote host information."""
 
     @endpoint
-    def get_info(self) -> tuple[str, str]:
-        return socket.gethostname(), _get_port()
+    def get_info(self) -> tuple[str, str, int]:
+        """Returns hostname, port, and GPU count."""
+        try:
+            gpu_count = torch.cuda.device_count()
+        except Exception:
+            # If torch is not available or CUDA is not available, assume no GPUs
+            gpu_count = 0
+        return socket.gethostname(), _get_port(), gpu_count
 
 
 class EnvSetter(Actor):
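
For reference, the GPU-count probe added to get_info() can be exercised on its own. A minimal sketch, assuming only that torch may or may not be importable; the helper name probe_gpu_count is illustrative and not part of this commit:

def probe_gpu_count() -> int:
    """Return the number of CUDA devices, or 0 if torch/CUDA is unavailable."""
    try:
        import torch

        return torch.cuda.device_count()
    except Exception:
        # torch missing or CUDA not initialized -> assume no GPUs
        return 0


print(probe_gpu_count())  # e.g. 0 on a CPU-only machine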
@@ -77,8 +85,8 @@ def set_env(self, env_vars: dict[str, str]):
             os.environ[k] = v
 
 
-async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
-    """Returns the host name and port of the host mesh."""
+async def get_host_info(host_mesh: HostMesh) -> tuple[str, str, int]:
+    """Returns the host name, port, and GPU count of the host mesh."""
     throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
     fetcher = throwaway_procs.spawn("_fetcher", _RemoteInfoFetcher)
 
@@ -88,11 +96,11 @@ async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
     fetcher = fetcher.slice(**singleton_slice)
     # Fetcher should be a singleton at this point - call_one() will fail otherwise
 
-    host, port = await fetcher.get_info.call_one()
+    host, port, gpu_count = await fetcher.get_info.call_one()
 
     # Stopping this proc is the right thing to do, but Monarch does not yet handle manual stops well.
     # await throwaway_procs.stop()
-    return host, port
+    return host, port, gpu_count
 
 
 async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
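
Callers now unpack a three-element tuple from get_host_info. A hedged sketch of a call site, assuming host_mesh is a HostMesh obtained from the provisioner:

async def describe_host(host_mesh) -> None:
    # Illustrative only; get_host_info is the helper defined above.
    host, port, gpu_count = await get_host_info(host_mesh)
    print(f"{host}:{port} exposes {gpu_count} GPU(s)")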
@@ -110,14 +118,37 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
 
 
 class GpuManager:
-    """Tracks and assigns GPU devices on a host."""
+    """Tracks and assigns GPU devices on a host.
+
+    Args:
+        available_devices: Set of GPU device IDs to manage. If None, uses all devices from 0 to max_device_count-1.
+        max_device_count: Maximum number of GPU devices on this host. Defaults to 8.
+
+    """
 
-    def __init__(self, available_devices: set[int] | None = None):
+    def __init__(
+        self, available_devices: set[int] | None = None, max_device_count: int = 8
+    ):
         if available_devices is None:
-            available_devices = set(range(0, 8))
-        assert all(isinstance(x, int) for x in available_devices)
-        assert all(x >= 0 and x < 8 for x in available_devices)
+            available_devices = set(range(0, max_device_count))
+        else:
+            # Validate types first
+            assert all(
+                isinstance(x, int) for x in available_devices
+            ), f"All device IDs must be integers, got: {available_devices}"
+            # When available_devices is provided (e.g., from CUDA_VISIBLE_DEVICES),
+            # adjust max_device_count to accommodate the highest device ID
+            if available_devices:
+                max_device_count = max(max(available_devices) + 1, max_device_count)
+
+        assert all(
+            isinstance(x, int) for x in available_devices
+        ), f"All device IDs must be integers, got: {available_devices}"
+        assert all(
+            x >= 0 for x in available_devices
+        ), f"All device IDs must be non-negative, got: {available_devices}"
         self.available_gpus = available_devices
+        self.max_device_count = max_device_count
 
     def get_available_gpus(self) -> list[str]:
         """Returns a list of available GPU devices."""
@@ -166,8 +197,18 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
                     f"Invalid CUDA_VISIBLE_DEVICES format: '{cuda_visible_devices}'. "
                     f"Expected comma-separated integers (e.g., '0,1,2'). Error: {e}"
                 ) from e
+
+        # Get the actual GPU count for the local host
+        try:
+            local_gpu_count = torch.cuda.device_count()
+        except Exception:
+            # If torch is not available or CUDA is not available, assume no GPUs
+            local_gpu_count = 0
+
         self._host_gpu_map = {
-            self._this_host_id: GpuManager(available_local_devices),
+            self._this_host_id: GpuManager(
+                available_local_devices, max_device_count=local_gpu_count
+            ),
         }
         self._proc_host_map = {}
         self._host_mesh_map = {}
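
The local-host wiring above ties the CUDA_VISIBLE_DEVICES parsing earlier in __init__ to the detected device count. A standalone sketch of that combination under assumed variable names, not the exact Provisioner code:

import os

# Parse CUDA_VISIBLE_DEVICES into an optional set of device IDs.
visible = os.environ.get("CUDA_VISIBLE_DEVICES")
available_local_devices = (
    {int(x) for x in visible.split(",")} if visible else None
)

# Probe the local device count, defaulting to 0 when torch/CUDA is unavailable.
try:
    import torch

    local_gpu_count = torch.cuda.device_count()
except Exception:
    local_gpu_count = 0

local_manager = GpuManager(available_local_devices, max_device_count=local_gpu_count)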
@@ -272,9 +313,18 @@ async def get_proc_mesh(
                     num_hosts=num_hosts,
                 )
                 host_id = uuid.uuid1()
-                gpu_manager = GpuManager()
+                # Get host info including GPU count from the remote host
+                host_addr, host_port, remote_gpu_count = await get_host_info(
+                    host_mesh
+                )
+                gpu_manager = GpuManager(max_device_count=remote_gpu_count)
                 self._host_gpu_map[host_id] = gpu_manager
                 host_mesh._host_id = host_id
+                # Use the fetched addr/port if not explicitly provided
+                if addr is None:
+                    addr = host_addr
+                if port is None:
+                    port = host_port
             else:
                 host_id = host_mesh._host_id
                 gpu_manager = self._host_gpu_map[host_id]
@@ -286,7 +336,7 @@ async def get_proc_mesh(
 
             if with_gpus:
                 if not addr or not port:
-                    addr, port = await get_remote_info(host_mesh)
+                    addr, port, _ = await get_host_info(host_mesh)
                 gpu_ids = gpu_manager.get_gpus(num_procs)
 
                 env_vars["MASTER_ADDR"] = addr

tests/unit_tests/test_provisioner.py

Lines changed: 48 additions & 5 deletions
@@ -45,9 +45,6 @@ def test_gpu_manager_invalid_device_range(self):
         with pytest.raises(AssertionError):
             GpuManager(available_devices={-1})  # Negative device
 
-        with pytest.raises(AssertionError):
-            GpuManager(available_devices={8})  # Device >= 8
-
         with pytest.raises(AssertionError):
             GpuManager(available_devices={"0"})  # String instead of int
 
@@ -90,7 +87,8 @@ class TestProvisionerCudaVisibleDevices:
     """Test Provisioner's handling of CUDA_VISIBLE_DEVICES environment variable."""
 
     @mock.patch.dict(os.environ, {}, clear=True)
-    def test_provisioner_no_cuda_visible_devices(self):
+    @mock.patch("torch.cuda.device_count", return_value=8)
+    def test_provisioner_no_cuda_visible_devices(self, mock_device_count):
         """Test Provisioner when CUDA_VISIBLE_DEVICES is not set."""
         provisioner = Provisioner()
 
@@ -135,7 +133,8 @@ def test_provisioner_duplicate_gpu_ids(self):
         assert len(available) == 3
 
     @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": ""}, clear=True)
-    def test_provisioner_empty_cuda_visible_devices(self):
+    @mock.patch("torch.cuda.device_count", return_value=8)
+    def test_provisioner_empty_cuda_visible_devices(self, mock_device_count):
         """Test Provisioner with empty CUDA_VISIBLE_DEVICES."""
         provisioner = Provisioner()
 
@@ -245,3 +244,47 @@ def test_single_gpu_scenario(self):
         # Release and verify
         local_gpu_manager.release_gpus(allocated)
         assert local_gpu_manager.get_available_gpus() == ["0"]
+
+
+class TestDynamicGpuDetection:
+    """Test dynamic GPU detection using torch.cuda.device_count()."""
+
+    @mock.patch.dict(os.environ, {}, clear=True)
+    @mock.patch("torch.cuda.device_count", return_value=4)
+    def test_provisioner_with_4_gpus(self, mock_device_count):
+        """Test Provisioner detects 4 GPUs when torch.cuda.device_count() returns 4."""
+        provisioner = Provisioner()
+
+        local_gpu_manager = provisioner._host_gpu_map[provisioner._this_host_id]
+        available = local_gpu_manager.get_available_gpus()
+        assert sorted(available) == ["0", "1", "2", "3"]
+        assert len(available) == 4
+        assert local_gpu_manager.max_device_count == 4
+
+    @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,2,4"}, clear=True)
+    @mock.patch("torch.cuda.device_count", return_value=8)
+    def test_cuda_visible_devices_with_detected_gpus(self, mock_device_count):
+        """Test that CUDA_VISIBLE_DEVICES works correctly with detected GPU count."""
+        provisioner = Provisioner()
+
+        local_gpu_manager = provisioner._host_gpu_map[provisioner._this_host_id]
+        available = local_gpu_manager.get_available_gpus()
+        # Should use CUDA_VISIBLE_DEVICES, not all 8 detected GPUs
+        assert sorted(available) == ["0", "2", "4"]
+        assert len(available) == 3
+        # max_device_count should still be 8 from detection
+        assert local_gpu_manager.max_device_count == 8
+
+    @mock.patch.dict(os.environ, {}, clear=True)
+    @mock.patch(
+        "torch.cuda.device_count", side_effect=RuntimeError("CUDA not available")
+    )
+    def test_provisioner_when_cuda_unavailable(self, mock_device_count):
+        """Test Provisioner defaults to 0 GPUs when CUDA is not available."""
+        provisioner = Provisioner()
+
+        local_gpu_manager = provisioner._host_gpu_map[provisioner._this_host_id]
+        available = local_gpu_manager.get_available_gpus()
+        assert available == []
+        assert len(available) == 0
+        assert local_gpu_manager.max_device_count == 0
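
A standalone sketch of the mocking pattern these tests rely on; imports and the Provisioner entry point are assumed from the test module above:

import os
from unittest import mock

with mock.patch.dict(os.environ, {}, clear=True), mock.patch(
    "torch.cuda.device_count", return_value=2
):
    provisioner = Provisioner()
    manager = provisioner._host_gpu_map[provisioner._this_host_id]
    assert manager.max_device_count == 2
    assert sorted(manager.get_available_gpus()) == ["0", "1"]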
