|
1 | 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import functools |
4 | 5 | import multiprocessing |
5 | 6 | import sys |
6 | | - |
7 | 7 | import torch |
8 | 8 |
|
| 9 | +from contextlib import contextmanager |
| 10 | + |
9 | 11 | from megatron.core.transformer.moe.moe_layer import MoELayer |
10 | 12 | from megatron.core.utils import get_model_config |
11 | 13 |
|
@@ -139,6 +141,34 @@ def tensor_swap(x, src_idxs, dst_idxs): |
139 | 141 | x[dst_idxs], x[src_idxs] = x[src_idxs], x[dst_idxs] |
140 | 142 |
|
141 | 143 |
|
def use_cuda_graph(graph_cache: dict, graph_key):
    """Decorator factory adding CUDA-graph capture/replay to a function.

    The decorated function gains a keyword-only ``graph_mode`` argument:

    * ``"record"`` — run the function once under ``torch.cuda.graph``
      capture and store the resulting ``torch.cuda.CUDAGraph`` in
      ``graph_cache`` under ``graph_key``. Overwriting an existing entry
      is disallowed.
    * ``"replay"`` (default) — replay the previously recorded graph.
      The wrapped function itself is NOT called on this path, and
      ``*args`` / ``**kwargs`` are ignored (the graph re-executes the
      originally captured work).

    Args:
        graph_cache: Mutable mapping from keys to recorded CUDA graphs.
        graph_key: Key under which this function's graph is stored.

    Returns:
        A decorator to apply to the function to be captured/replayed.
    """
    # BUGFIX: this is a decorator factory, not a context manager. The previous
    # ``@contextmanager`` wrapping was incorrect — the function has no
    # ``yield``, so entering (or applying) the resulting generator context
    # manager raised ``RuntimeError: generator didn't yield`` unconditionally.
    def deco(fn):
        @functools.wraps(fn)
        def wrapped(*args, graph_mode: str = "replay", **kwargs):
            assert graph_mode in ("record", "replay")
            if graph_mode == "record":
                # Do not allow for overwriting of existing graphs.
                assert graph_key not in graph_cache

                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    fn(*args, **kwargs)
                graph_cache[graph_key] = g
            elif graph_mode == "replay":
                graph_cache[graph_key].replay()
        return wrapped
    return deco
| 163 | + |
async def torch_awaitable(stream: torch.cuda.Stream | None = None):
    """Await completion of the work currently queued on a CUDA stream.

    Records an event on ``stream`` (the current stream when ``None``) and
    cooperatively yields to the event loop until the event reports done,
    so other coroutines can run while the GPU finishes its work.
    """
    target = torch.cuda.current_stream() if stream is None else stream
    done_marker = target.record_event()
    # Busy-poll cooperatively: sleep(0) hands control back to the loop
    # on every pass instead of blocking the thread on the event.
    while True:
        if done_marker.query():
            break
        await asyncio.sleep(0)
| 171 | + |
142 | 172 | async def await_process_event( |
143 | 173 | event: multiprocessing.Event, process: multiprocessing.Process, timeout: float = 1.0 |
144 | 174 | ) -> None: |
|
0 commit comments