Skip to content

Commit 7b14df5

Browse files
author
Allen Wang
committed
separate out gpu counter
1 parent 8e2c06b commit 7b14df5

File tree

1 file changed

+28
-18
lines changed

1 file changed

+28
-18
lines changed

src/forge/controller/provisioner.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,19 @@ class _RemoteInfoFetcher(Actor):
4242
"""An actor responsible for getting remote host information."""
4343

4444
@endpoint
def get_info(self) -> tuple[str, str]:
    """Return this host's name together with a free port."""
    hostname = socket.gethostname()
    port = _get_port()
    return hostname, port
48+
49+
@endpoint
def get_gpu_count(self) -> int:
    """Return how many CUDA GPUs this host exposes (0 when unavailable)."""
    try:
        return torch.cuda.device_count()
    except Exception:
        # Best-effort probe: if torch is not importable or CUDA is not
        # available, report zero GPUs rather than propagating the failure.
        return 0
5358

5459

5560
class EnvSetter(Actor):
@@ -85,8 +90,8 @@ def set_env(self, env_vars: dict[str, str]):
8590
os.environ[k] = v
8691

8792

88-
async def get_host_info(host_mesh: HostMesh) -> tuple[str, str, int]:
89-
"""Returns the host name, port, and GPU count of the host mesh."""
93+
async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
    """Return the (hostname, port) pair of the given host mesh.

    Spawns a single throwaway proc on the mesh and asks it for its own
    address information via a ``_RemoteInfoFetcher`` actor.
    """
    probe_procs = host_mesh.spawn_procs(per_host={"procs": 1})
    fetcher = probe_procs.spawn("_fetcher", _RemoteInfoFetcher)

    # Slice every mesh dimension down to a single element so exactly one
    # actor remains -- call_one() below requires a singleton.
    fetcher = fetcher.slice(**{dim: slice(0, 1) for dim in fetcher.extent.keys()})
    host, port = await fetcher.get_info.call_one()

    # Stopping this proc is the right thing to do, but Monarch does not yet
    # handle manual stops well.
    # await probe_procs.stop()
    return host, port
108+
109+
110+
async def get_host_gpus(host_mesh: HostMesh) -> int:
    """Return the number of GPUs available on the host mesh."""
    probe_procs = host_mesh.spawn_procs(per_host={"procs": 1})
    counter = probe_procs.spawn("_gpu_counter", _RemoteInfoFetcher)

    # Reduce the actor mesh to a singleton so call_one() is valid.
    counter = counter.slice(**{dim: slice(0, 1) for dim in counter.extent.keys()})

    # NOTE(review): probe_procs is never stopped here -- presumably the same
    # Monarch manual-stop limitation noted in get_remote_info; confirm.
    gpu_count = await counter.get_gpu_count.call_one()
    return gpu_count
104121

105122

106123
async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
@@ -313,18 +330,11 @@ async def get_proc_mesh(
313330
num_hosts=num_hosts,
314331
)
315332
host_id = uuid.uuid1()
316-
# Get host info including GPU count from the remote host
317-
host_addr, host_port, remote_gpu_count = await get_host_info(
318-
host_mesh
319-
)
333+
# Get the GPU count from the remote host
334+
remote_gpu_count = await get_host_gpus(host_mesh)
320335
gpu_manager = GpuManager(max_device_count=remote_gpu_count)
321336
self._host_gpu_map[host_id] = gpu_manager
322337
host_mesh._host_id = host_id
323-
# Use the fetched addr/port if not explicitly provided
324-
if addr is None:
325-
addr = host_addr
326-
if port is None:
327-
port = host_port
328338
else:
329339
host_id = host_mesh._host_id
330340
gpu_manager = self._host_gpu_map[host_id]
@@ -336,7 +346,7 @@ async def get_proc_mesh(
336346

337347
if with_gpus:
338348
if not addr or not port:
339-
addr, port, _ = await get_host_info(host_mesh)
349+
addr, port = await get_remote_info(host_mesh)
340350
gpu_ids = gpu_manager.get_gpus(num_procs)
341351

342352
env_vars["MASTER_ADDR"] = addr

0 commit comments

Comments (0)