Skip to content

Commit a7d508c

Browse files
committed
Syntactic sugar for CG and awaiting
1 parent c59b47c commit a7d508c

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

megatron/core/inference/utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import asyncio
44
import multiprocessing
55
import sys
6-
76
import torch
87

98
from megatron.core.transformer.moe.moe_layer import MoELayer
@@ -138,6 +137,28 @@ def tensor_swap(x, src_idxs, dst_idxs):
138137
"""
139138
x[dst_idxs], x[src_idxs] = x[src_idxs], x[dst_idxs]
140139

140+
def use_cuda_graph(graph_cache: dict, graph_key, fn):
    """Record-or-replay a CUDA graph for fn().

    On first call for a given graph_key, captures fn() into a CUDA graph.
    On subsequent calls with the same key, replays the cached graph.
    fn must be a zero-argument callable operating on static-address tensors.
    """
    # Cache hit: replay the graph captured on an earlier call and bail out.
    if graph_key in graph_cache:
        graph_cache[graph_key].replay()
        return

    # Cache miss: capture one execution of fn() into a fresh graph,
    # then memoize it under graph_key for future replays.
    captured = torch.cuda.CUDAGraph()
    with torch.cuda.graph(captured):
        fn()
    graph_cache[graph_key] = captured
154+
155+
async def torch_awaitable(stream: torch.cuda.Stream | None = None):
156+
"""Syntactic sugar for returning an awaitable handle for non-distributed torch."""
157+
if stream is None:
158+
stream = torch.cuda.current_stream()
159+
event = stream.record_event()
160+
while not event.query():
161+
await asyncio.sleep(0)
141162

142163
async def await_process_call(call, process: multiprocessing.Process, timeout: float = 1.0):
143164
"""Repeatedly wait for a multiprocessing callable to resolve, aborting upon process failure.

0 commit comments

Comments
 (0)