|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 | 7 | # pyre-unsafe
|
8 |
| -import asyncio |
9 | 8 | import ctypes
|
10 | 9 | import functools
|
11 | 10 | import logging
|
@@ -118,27 +117,25 @@ def _get_addr_and_size(buf: torch.Tensor | memoryview) -> tuple[int, int]:
|
118 | 117 |
|
119 | 118 | class RdmaController(Actor):
|
120 | 119 | def __init__(self) -> None:
|
121 |
| - self._managers: Dict[ProcMesh, _RdmaManager] = {} |
122 |
| - self._lock = asyncio.Lock() |
| 120 | + self._manager_futures: Dict[ProcMesh, Future[_RdmaManager]] = {} |
123 | 121 |
|
124 | 122 | @endpoint
|
125 | 123 | async def init_rdma_on_mesh(self, proc_mesh: ProcMesh) -> None:
|
126 | 124 | # Note: RdmaController acts as coordinator and can run on any node
|
127 | 125 | # The RDMA support check should happen on the target proc_mesh nodes, not on RdmaController's node
|
128 | 126 |
|
129 |
| - if proc_mesh in self._managers: |
130 |
| - return |
131 |
| - |
132 |
| - async with self._lock: |
133 |
| - if proc_mesh not in self._managers: |
134 |
| - self._managers[proc_mesh] = none_throws( |
135 |
| - await Future( |
136 |
| - coro=_RdmaManager.create_rdma_manager_nonblocking( |
137 |
| - await Future(coro=proc_mesh._proc_mesh.task()) |
138 |
| - ) |
139 |
| - ) |
| 127 | + if proc_mesh not in self._manager_futures: |
| 128 | + |
| 129 | + async def create_manager() -> _RdmaManager: |
| 130 | + proc_mesh_result = await Future(coro=proc_mesh._proc_mesh.task()) |
| 131 | + return none_throws( |
| 132 | + await _RdmaManager.create_rdma_manager_nonblocking(proc_mesh_result) |
140 | 133 | )
|
141 | 134 |
|
| 135 | + self._manager_futures[proc_mesh] = Future(coro=create_manager()) |
| 136 | + |
| 137 | + await self._manager_futures[proc_mesh] |
| 138 | + |
142 | 139 |
|
143 | 140 | @functools.cache
|
144 | 141 | def _check_cuda_expandable_segments_enabled() -> bool:
|
|
0 commit comments