|
3 | 3 | import asyncio |
4 | 4 | import multiprocessing |
5 | 5 | import sys |
6 | | - |
7 | 6 | import torch |
8 | 7 |
|
9 | 8 | from megatron.core.transformer.moe.moe_layer import MoELayer |
@@ -138,6 +137,28 @@ def tensor_swap(x, src_idxs, dst_idxs): |
138 | 137 | """ |
139 | 138 | x[dst_idxs], x[src_idxs] = x[src_idxs], x[dst_idxs] |
140 | 139 |
|
def use_cuda_graph(graph_cache: dict, graph_key, fn):
    """Record-or-replay a CUDA graph for fn().

    On first call for a given graph_key, captures fn() into a CUDA graph.
    On subsequent calls with the same key, replays the cached graph.
    fn must be a zero-argument callable operating on static-address tensors.
    """
    cached = graph_cache.get(graph_key)
    if cached is not None:
        # Fast path: a graph was already captured under this key.
        cached.replay()
        return
    # Slow path: capture fn() into a fresh graph and memoize it.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        fn()
    graph_cache[graph_key] = graph
| 154 | + |
| 155 | +async def torch_awaitable(stream: torch.cuda.Stream | None = None): |
| 156 | + """Syntactic sugar for returning an awaitable handle for non-distributed torch.""" |
| 157 | + if stream is None: |
| 158 | + stream = torch.cuda.current_stream() |
| 159 | + event = stream.record_event() |
| 160 | + while not event.query(): |
| 161 | + await asyncio.sleep(0) |
141 | 162 |
|
142 | 163 | async def await_process_call(call, process: multiprocessing.Process, timeout: float = 1.0): |
143 | 164 | """Repeatedly wait for a multiprocessing callable to resolve, aborting upon process failure. |
|
0 commit comments