73 changes: 64 additions & 9 deletions src/forge/controller/provisioner.py
@@ -12,6 +12,8 @@
import socket
import uuid

import torch

from monarch._src.actor.actor_mesh import ActorMesh
from monarch._src.actor.shape import Extent

@@ -41,8 +43,19 @@ class _RemoteInfoFetcher(Actor):

@endpoint
def get_info(self) -> tuple[str, str]:
"""Returns hostname and port."""
return socket.gethostname(), _get_port()

@endpoint
def get_gpu_count(self) -> int:
"""Returns the number of GPUs available on this host."""
try:
gpu_count = torch.cuda.device_count()
except Exception:
# If CUDA is unavailable or the device query fails, assume no GPUs
gpu_count = 0
return gpu_count


class EnvSetter(Actor):
"""Actor to set environment variables on each proc in a mesh.
@@ -87,14 +100,26 @@ async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
fetcher = fetcher.slice(**singleton_slice)
# Fetcher should be a singleton at this point - call_one() will fail otherwise

host, port = await fetcher.get_info.call_one()

# Stopping this proc is the right thing to do, but Monarch does not yet handle manual stops well.
# await throwaway_procs.stop()
return host, port


async def get_host_gpus(host_mesh: HostMesh) -> int:
"""Returns the number of GPUs available on the host mesh."""
throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
fetcher = throwaway_procs.spawn("_gpu_counter", _RemoteInfoFetcher)

# Reduce to a singleton
singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
fetcher = fetcher.slice(**singleton_slice)

gpu_count = await fetcher.get_gpu_count.call_one()
return gpu_count
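Both helpers spawn their own throwaway proc and `_RemoteInfoFetcher`, so discovering a new host currently costs two spawn round trips. As a design note, a minimal sketch of how the two queries could share one fetcher (an assumption, not something this diff does; it reuses only the APIs already shown above):

procs = host_mesh.spawn_procs(per_host={"procs": 1})
fetcher = procs.spawn("_remote_info", _RemoteInfoFetcher)
# Reduce to a singleton so call_one() is valid
fetcher = fetcher.slice(**{k: slice(0, 1) for k in fetcher.extent.keys()})
host, port = await fetcher.get_info.call_one()
gpu_count = await fetcher.get_gpu_count.call_one()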


async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
"""Set environment variables on a proc mesh using EnvSetter actor.

@@ -112,17 +137,35 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
class GpuManager:
"""Tracks and assigns GPU devices on a host.

This currently mimics the `gpu_manager` in system_controllers - we will
consolidate as part of the "proper HostMesh integration" work.
Args:
available_devices: Set of GPU device IDs to manage. If None, uses all devices from 0 to max_device_count-1.
max_device_count: Maximum number of GPU devices on this host. Defaults to 8.

"""

def __init__(self, available_devices: set[int] | None = None):
def __init__(
self, available_devices: set[int] | None = None, max_device_count: int = 8
):
if available_devices is None:
available_devices = set(range(0, 8))
assert all(isinstance(x, int) for x in available_devices)
assert all(x >= 0 and x < 8 for x in available_devices)
available_devices = set(range(0, max_device_count))
else:
# Validate types first
assert all(
isinstance(x, int) for x in available_devices
), f"All device IDs must be integers, got: {available_devices}"
# When available_devices is provided (e.g., from CUDA_VISIBLE_DEVICES),
# adjust max_device_count to accommodate the highest device ID
if available_devices:
max_device_count = max(max(available_devices) + 1, max_device_count)

assert all(
x >= 0 for x in available_devices
), f"All device IDs must be non-negative, got: {available_devices}"
self.available_gpus = available_devices
self.max_device_count = max_device_count

def get_available_gpus(self) -> list[str]:
"""Returns a list of available GPU devices."""
@@ -171,8 +214,18 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
f"Invalid CUDA_VISIBLE_DEVICES format: '{cuda_visible_devices}'. "
f"Expected comma-separated integers (e.g., '0,1,2'). Error: {e}"
) from e

# Get the actual GPU count for the local host
try:
local_gpu_count = torch.cuda.device_count()
except Exception:
# If CUDA is unavailable or the device query fails, assume no GPUs
local_gpu_count = 0

self._host_gpu_map = {
self._this_host_id: GpuManager(available_local_devices),
self._this_host_id: GpuManager(
available_local_devices, max_device_count=local_gpu_count
),
}
self._proc_host_map = {}
self._host_mesh_map = {}
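Tying the local-host pieces together, a hypothetical example (assuming a host with 4 physical GPUs, CUDA_VISIBLE_DEVICES="2,3", and the parsing shown in the earlier hunk):

available_local_devices = {2, 3}  # parsed from CUDA_VISIBLE_DEVICES above
local_gpu_count = 4               # what torch.cuda.device_count() would report here
manager = GpuManager(available_local_devices, max_device_count=local_gpu_count)
assert manager.available_gpus == {2, 3}  # only visible devices are handed out
assert manager.max_device_count == 4     # max(3 + 1, 4)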
@@ -277,7 +330,9 @@ async def get_proc_mesh(
num_hosts=num_hosts,
)
host_id = uuid.uuid1()
gpu_manager = GpuManager()
# Get the GPU count from the remote host
remote_gpu_count = await get_host_gpus(host_mesh)
gpu_manager = GpuManager(max_device_count=remote_gpu_count)
self._host_gpu_map[host_id] = gpu_manager
host_mesh._host_id = host_id
else:
12 changes: 0 additions & 12 deletions src/forge/controller/system_controllers/__init__.py

This file was deleted.

73 changes: 0 additions & 73 deletions src/forge/controller/system_controllers/gpu_manager.py

This file was deleted.
