Skip to content

Commit 63c19dc

Browse files
committed
Syntactic sugar for CG and awaiting
1 parent f4b8676 commit 63c19dc

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

megatron/core/inference/utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
22

33
import asyncio
4+
import functools
45
import multiprocessing
56
import sys
6-
77
import torch
88

9+
from contextlib import contextmanager
10+
911
from megatron.core.transformer.moe.moe_layer import MoELayer
1012
from megatron.core.utils import get_model_config
1113

@@ -139,6 +141,34 @@ def tensor_swap(x, src_idxs, dst_idxs):
139141
x[dst_idxs], x[src_idxs] = x[src_idxs], x[dst_idxs]
140142

141143

144+
def use_cuda_graph(graph_cache: dict, graph_key):
    """Syntactic sugar decorator to simplify CUDA graph capture and replay.

    BUG FIX: the original was decorated with ``@contextmanager`` although it
    contains no ``yield`` — it is a decorator *factory*, not a generator, so
    entering the resulting context manager raised ``TypeError`` when
    ``next()`` was called on the plain ``deco`` function. The spurious
    decorator is removed; usage is ``@use_cuda_graph(cache, key)``.

    The wrapped function gains a keyword-only ``graph_mode`` argument:
      * ``"record"`` — run ``fn`` once under ``torch.cuda.graph`` capture and
        store the resulting ``CUDAGraph`` in ``graph_cache[graph_key]``.
        Overwriting an existing graph is disallowed.
      * ``"replay"`` (default) — replay the previously captured graph.

    Note: the wrapped call always returns ``None``; ``fn``'s return value is
    discarded in both modes (graph replay cannot return fresh outputs).

    Args:
        graph_cache: mutable mapping from keys to captured ``CUDAGraph``s.
        graph_key: key under which this function's graph is stored.
    """

    def deco(fn):
        @functools.wraps(fn)
        def wrapped(*args, graph_mode: str = "replay", **kwargs):
            assert graph_mode in ("record", "replay")
            if graph_mode == "record":
                # Do not allow for overwriting of existing graphs.
                assert graph_key not in graph_cache

                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    fn(*args, **kwargs)
                graph_cache[graph_key] = g
            elif graph_mode == "replay":
                graph_cache[graph_key].replay()

        return wrapped

    return deco
163+
164+
async def torch_awaitable(stream: torch.cuda.Stream | None = None):
165+
"""Syntactic sugar for returning an awaitable handle for non-distributed torch."""
166+
if stream is None:
167+
stream = torch.cuda.current_stream()
168+
event = stream.record_event()
169+
while not event.query():
170+
await asyncio.sleep(0)
171+
142172
async def await_process_event(
143173
event: multiprocessing.Event, process: multiprocessing.Process, timeout: float = 1.0
144174
) -> None:

0 commit comments

Comments
 (0)