|
5 | 5 | import multiprocessing |
6 | 6 | import sys |
7 | 7 | import time |
8 | | - |
9 | 8 | import torch |
10 | 9 |
|
11 | 10 | from megatron.core.transformer.moe.moe_layer import MoELayer |
@@ -218,6 +217,31 @@ def tensor_swap(x, src_idxs, dst_idxs): |
218 | 217 | x[dst_idxs], x[src_idxs] = x[src_idxs], x[dst_idxs] |
219 | 218 |
|
220 | 219 |
|
def use_cuda_graph(graph_cache: dict, graph_key, fn):
    """Record-or-replay a CUDA graph for ``fn()``.

    On the first call for a given ``graph_key``, captures ``fn()`` into a
    CUDA graph, caches it, and replays it once. The explicit replay matters:
    CUDA stream capture only *records* the launched kernels — it does not
    execute them — so without it the first invocation would silently produce
    no results. Subsequent calls with the same key replay the cached graph.

    Args:
        graph_cache: Mapping of ``graph_key -> torch.cuda.CUDAGraph``,
            mutated in place when a new graph is captured.
        graph_key: Hashable identifier for the captured graph.
        fn: Zero-argument callable; every tensor it touches must live at a
            static address so replays read/write the same memory.

    NOTE(review): ``torch.cuda.graph`` expects the captured ops to have been
    warmed up beforehand (e.g. a few eager iterations on a side stream) --
    confirm callers perform that warmup before the first call here.
    """
    graph = graph_cache.get(graph_key)
    if graph is None:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            fn()  # recorded, not executed, during capture
        graph_cache[graph_key] = graph
    # Replay on every path: cached graphs on later calls, and once right
    # after capture so the first call also performs fn's work.
    graph.replay()
| 234 | + |
| 235 | + |
async def torch_awaitable(stream: torch.cuda.Stream | None = None):
    """Syntactic sugar for returning an awaitable handle for non-distributed torch.

    Records an event on ``stream`` (the current CUDA stream when omitted)
    and cooperatively polls it, yielding control to the event loop between
    polls until the GPU work enqueued so far has completed.
    """
    target_stream = torch.cuda.current_stream() if stream is None else stream
    completion_event = target_stream.record_event()
    # Busy-poll with a zero-length sleep so other coroutines run while the
    # device drains the stream.
    while not completion_event.query():
        await asyncio.sleep(0)
| 243 | + |
| 244 | + |
221 | 245 | async def await_process_call(call, process: multiprocessing.Process, timeout: float = 1.0): |
222 | 246 | """Repeatedly wait for a multiprocessing callable to resolve, aborting upon process failure. |
223 | 247 |
|
|
0 commit comments