Skip to content

Commit a68e831

Browse files
committed
fix mapping issue + a repro test
1 parent a668434 commit a68e831

File tree

4 files changed

+64
-12
lines changed

4 files changed

+64
-12
lines changed

assets/versions.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ PYTORCH_VERSION="2.9.0.dev20250905"
1414
VLLM_BRANCH="v0.10.0"
1515

1616
# Commit hashes
17-
MONARCH_COMMIT="2f14096083b1cc1dac6ae15220e4135bc23f9dd3"
17+
MONARCH_COMMIT="main"
1818
TORCHTITAN_COMMIT="d0e25450bcac2332359b13fbda430dc701f073d4"
1919
TORCHSTORE_COMMIT="662299faf4fd50ee30bd9aa3f4ce8c0e2db1d310"

src/forge/actors/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
154154
worker_procs = await get_proc_mesh(process_config=process_config)
155155

156156
# Then, grab a single host from the workers...
157-
host_mesh = await host_mesh_from_proc(worker_procs)
157+
host_mesh = await host_mesh_from_proc(worker_procs._uid)
158158
singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
159159
host_mesh = host_mesh.slice(**singleton_slice)
160160

src/forge/controller/provisioner.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,11 @@ def bootstrap(env: dict[str, str]):
294294
per_host={"procs": num_procs},
295295
bootstrap=functools.partial(bootstrap, env=env_vars),
296296
)
297+
uid = str(uuid.uuid4())
298+
# Generate a unique ID to map procmesh to hostmesh
299+
procs._uid = uid
300+
print(f"Allocating procmesh with uid={uid}")
301+
print(f"Allocating procs._uid: {procs._uid}")
297302

298303
if with_gpus:
299304
# Set up environment variables for PyTorch distributed...
@@ -319,28 +324,35 @@ def bootstrap(env: dict[str, str]):
319324
self._server_names.append(server_name)
320325
self._proc_server_map[procs] = server_name
321326

322-
self._proc_host_map[procs] = host_mesh
327+
self._proc_host_map[uid] = host_mesh
323328

324329
# Spawn LocalFetcherActor for this ProcMesh and register with GlobalLoggingActor.
325330
# When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh.
326331
if not FORGE_DISABLE_METRICS.get_value():
327332
from forge.observability.metric_actors import get_or_create_metric_logger
328333

329334
_ = await get_or_create_metric_logger(procs, process_name=mesh_name)
330-
return procs
335+
336+
print(f"Returning procmesh with uid={uid}")
337+
print(f"Returning procs._uid: {procs._uid}")
338+
return procs, uid
331339

332340
@endpoint
333-
async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
334-
if proc_mesh not in self._proc_host_map:
341+
async def host_mesh_from_proc(self, uid: str | None):
342+
# uid: str | None = getattr(proc_mesh, "_uid", None)
343+
print(f"self._proc_host_map: {self._proc_host_map}")
344+
print(f"proc_mesh._uid: {uid}")
345+
if uid is None or uid not in self._proc_host_map:
335346
raise ValueError(
336347
"The proc mesh was not allocated with an associated hostmesh."
337348
)
338-
return self._proc_host_map[proc_mesh]
349+
return self._proc_host_map[uid]
339350

340351
@endpoint
341352
async def stop_proc_mesh(self, proc_mesh: ProcMesh):
342353
"""Stops a proc mesh."""
343-
if proc_mesh not in self._proc_host_map:
354+
uid: str | None = getattr(proc_mesh, "_uid", None)
355+
if uid is None or uid not in self._proc_host_map:
344356
logger.warning(
345357
f"proc mesh {proc_mesh} was requested to be stopped, but was either already stopped or "
346358
"was never registered with the provisioner."
@@ -363,7 +375,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
363375
if proc_mesh in self._proc_server_map:
364376
server_name = self._proc_server_map[proc_mesh]
365377
commands.kill(server_name)
366-
del self._proc_host_map[proc_mesh]
378+
del self._proc_host_map[uid]
367379

368380
@endpoint
369381
def register_service(self, service: "ServiceInterface") -> None:
@@ -464,7 +476,7 @@ async def get_proc_mesh(
464476
465477
"""
466478
provisioner = await get_or_create_provisioner()
467-
return await provisioner.get_proc_mesh.call_one(
479+
procs, uid = await provisioner.get_proc_mesh.call_one(
468480
num_procs=process_config.procs,
469481
with_gpus=process_config.with_gpus,
470482
num_hosts=process_config.hosts,
@@ -474,17 +486,20 @@ async def get_proc_mesh(
474486
port=port,
475487
addr=addr,
476488
)
489+
setattr(procs, "_uid", uid)
490+
print(f"Setting procs._uid: {procs._uid}")
491+
return procs
477492

478493

479-
async def host_mesh_from_proc(proc_mesh: ProcMesh):
494+
async def host_mesh_from_proc(uid: str | None):
480495
"""Returns the host mesh that allocated the original proc_mesh.
481496
482497
This functionality will be enabled in Monarch, so this is a temporary
483498
API.
484499
485500
"""
486501
provisioner = await get_or_create_provisioner()
487-
return await provisioner.host_mesh_from_proc.call_one(proc_mesh)
502+
return await provisioner.host_mesh_from_proc.call_one(uid)
488503

489504

490505
async def register_service(service: "ServiceInterface") -> None:

test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) Meta Platforms, Inc.
2+
# All rights reserved.
3+
#
4+
# Minimal repro: Provisioner host_mesh_from_proc() UID mapping bug
5+
#
6+
# Run this with:
7+
# python -m forge.tests.test_provisioner_uid_mapping
8+
9+
import asyncio
10+
11+
# import pytest
12+
13+
from forge.controller.provisioner import (
14+
get_or_create_provisioner,
15+
get_proc_mesh,
16+
stop_proc_mesh,
17+
)
18+
from forge.types import ProcessConfig
19+
20+
21+
# @pytest.mark.asyncio
22+
async def test_provisioner_host_mesh_lookup_uid_mapping():
23+
prov = await get_or_create_provisioner()
24+
pm = await get_proc_mesh(
25+
ProcessConfig(procs=1, with_gpus=False, hosts=None, mesh_name="uid_repro")
26+
)
27+
# UID is attached locally by the helper
28+
assert hasattr(pm, "_uid") and pm._uid, "missing _uid on returned ProcMesh"
29+
print(f"✅ got ProcMesh with UID {pm._uid}")
30+
hm = await prov.host_mesh_from_proc.call_one(pm._uid) # if pass pm, _uid is None
31+
assert hm is not None
32+
await stop_proc_mesh(pm)
33+
print("✅ repro passed")
34+
35+
36+
if __name__ == "__main__":
37+
asyncio.run(test_provisioner_host_mesh_lookup_uid_mapping())

0 commit comments

Comments
 (0)