Skip to content

Commit 37c081a

Browse files
committed
Syntactic sugar for CG and awaiting
1 parent 60bcc92 commit 37c081a

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

megatron/core/inference/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import multiprocessing
66
import sys
77
import time
8-
98
import torch
109

1110
from megatron.core.transformer.moe.moe_layer import MoELayer
@@ -218,6 +217,31 @@ def tensor_swap(x, src_idxs, dst_idxs):
218217
x[dst_idxs], x[src_idxs] = x[src_idxs], x[dst_idxs]
219218

220219

220+
def use_cuda_graph(graph_cache: dict, graph_key, fn):
    """Record-or-replay a CUDA graph for fn().

    On first call for a given graph_key, captures fn() into a CUDA graph.
    On subsequent calls with the same key, replays the cached graph.
    fn must be a zero-argument callable operating on static-address tensors.
    """
    # Fast path: a graph was already captured under this key — just replay it.
    if graph_key in graph_cache:
        graph_cache[graph_key].replay()
        return

    # Slow path: capture fn() into a new CUDA graph and memoize it.
    captured = torch.cuda.CUDAGraph()
    with torch.cuda.graph(captured):
        fn()
    graph_cache[graph_key] = captured
234+
235+
236+
async def torch_awaitable(stream: torch.cuda.Stream | None = None):
237+
"""Syntactic sugar for returning an awaitable handle for non-distributed torch."""
238+
if stream is None:
239+
stream = torch.cuda.current_stream()
240+
event = stream.record_event()
241+
while not event.query():
242+
await asyncio.sleep(0)
243+
244+
221245
async def await_process_call(call, process: multiprocessing.Process, timeout: float = 1.0):
222246
"""Repeatedly wait for a multiprocessing callable to resolve, aborting upon process failure.
223247

0 commit comments

Comments
 (0)