Commit 3c2e6f8

[KERNELS] vllm compatible version of CUDA Graph tracing for expert parallelism (#8563)
1 parent 2156b05 commit 3c2e6f8

5 files changed, +253 -137 lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 4 additions & 1 deletion
@@ -70,6 +70,8 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
 
     input_x = torch.randn((batch // DP, dim1), device=dev)
     expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev))
+    triton_dist.initialize_matmul_ogs(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype)
+
     # run layer
     fpath = Path(tempfile.mktemp())
     proton.start(str(fpath), hook="triton")
@@ -79,7 +81,8 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
     if n_expts_tot > 1: # sparse
         logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
         x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act, EP=EP,
-                                                                            TP=TP, expt_assignment=expt_assignment)
+                                                                            TP=TP, expt_assignment=expt_assignment,
+                                                                            mode="ep_sharding")
     else: # dense
         x = triton_dist.all_gather(input_x, dim=0)
         rdata, gather_indx, scatter_indx, metadata = None, None, None, None
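
Taken together, the two hunks above change the benchmark setup: the symmetric-memory pool is sized once, right after the expert assignment is created, and the routing call then opts into expert-parallel sharding explicitly. A minimal sketch of the resulting call order; the shapes, parallelism factors, and the `import distributed as triton_dist` alias are illustrative assumptions rather than values taken from the benchmark:

```python
import torch
import distributed as triton_dist  # assumed alias for the bench helper module shown in the next file

# Illustrative configuration; bench_mlp derives these from its arguments.
batch, dim1, dim2 = 4096, 2048, 2048
n_expts_tot, n_expts_act = 128, 4
EP = TP = DP = 1
dev = "cuda"

input_x = torch.randn((batch // DP, dim1), device=dev)
expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev))
# New in this commit: pre-size the symmetric-memory pool before any matmul_ogs
# call that may later run under CUDA graph capture.
triton_dist.initialize_matmul_ogs(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype)

# Later, per layer, routing opts into expert-parallel sharding explicitly:
# x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(
#     input_x, logits, n_expts_act, EP=EP, TP=TP,
#     expt_assignment=expt_assignment, mode="ep_sharding")
```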

python/triton_kernels/bench/distributed.py

Lines changed: 52 additions & 7 deletions
@@ -16,7 +16,7 @@
 from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq
 from triton_kernels.tensor_details import layout
 from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata
-from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, convert_dp_to_ep, convert_ep_to_dp, ExptAssignment
+from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, convert_dp_to_ep, convert_ep_to_dp, ExptAssignment, symm_mem_pool
 
 from bench_utils import quantize_weight
 
@@ -40,6 +40,31 @@ def create_expt_assignment(EP: int, n_expts_tot: int, device: torch.device) -> O
     return make_expt_assignment(EP, n_expts_tot, expt_dict, device)
 
 
+def initialize_matmul_ogs(
+    batch: int,
+    dim1: int,
+    dim2: int,
+    n_expts_act: int,
+    n_expts_tot: int,
+    dtype: torch.dtype,
+) -> None:
+    if not _is_distributed_launch():
+        return
+    world_size = dist.get_world_size()
+    device = torch.cuda.current_device()
+    symm_mem_pool.initialize_matmul_ogs(
+        n_tokens_global=batch,
+        d_input=dim1,
+        d_model=dim2,
+        n_expts_act=n_expts_act,
+        n_expts_tot=n_expts_tot,
+        n_ranks=world_size,
+        dtype=dtype,
+        group=dist.group.WORLD,
+        device=device,
+    )
+
+
 def setup() -> Tuple[int, int]:
     if _is_distributed_launch():
         world_size = int(os.environ["WORLD_SIZE"])
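
The new wrapper is a thin front end over `symm_mem_pool.initialize_matmul_ogs`: it is a no-op outside a distributed launch, and otherwise fills in the world size, process group, and device while mapping the benchmark's `batch`, `dim1`, and `dim2` onto the pool's `n_tokens_global`, `d_input`, and `d_model`. A hedged usage sketch; the launch command, shapes, and the assumption that `setup()` brings up the process group are illustrative:

```python
# Launched e.g. via `torchrun --nproc-per-node=8 my_driver.py` (launch command is an assumption).
import torch
from distributed import initialize_matmul_ogs, setup  # bench helpers from this file

setup()  # assumed to initialize the default process group when WORLD_SIZE/RANK are set
# Size the symmetric-memory pool once, before any matmul_ogs call or CUDA graph
# capture; in a plain single-process run the call simply returns.
initialize_matmul_ogs(batch=4096, dim1=2048, dim2=2048,
                      n_expts_act=4, n_expts_tot=128, dtype=torch.bfloat16)
```
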
@@ -112,11 +137,18 @@ def reduce_scatter(
 # TODO: clean up duplicate code with triton_kernels.test_distributed.py
 # TODO: Support nonuniform expert assignment
 def routing(
-    x, logits, n_expts_act, sm_first: bool = False, y_indx: Optional[torch.Tensor] = None, EP: int = 1, TP: int = 1,
-    expt_assignment: Optional[ExptAssignment] = None, mode: str = "ep_sharding"
+    x,
+    logits,
+    n_expts_act,
+    sm_first: bool = False,
+    y_indx: Optional[torch.Tensor] = None,
+    EP: int = 1,
+    TP: int = 1,
+    expt_assignment: Optional[ExptAssignment] = None,
+    mode: Optional[str] = None,
 ) -> Tuple[torch.Tensor, RoutingData, GatherIndx, ScatterIndx, Optional[ReduceScatterMetadata]]:
     n_expts_tot = logits.shape[-1]
-    if _is_distributed_launch():
+    if _is_distributed_launch() and mode:
         if mode == "ep_sharding":
             if not expt_assignment:
                 raise ValueError("expt_assignment must be provided for distributed routing.")
@@ -150,6 +182,7 @@ def routing(
         else:
             raise NotImplementedError(f"Distributed routing mode {mode} is not implemented yet.")
     else:
+        # If mode is not specified or we have a single process, we do single-GPU routing.
         logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first)
         dispatch_indx = logits.mask_metadata.col_sorted_indx
         combine_indx = logits.mask_metadata.row_sorted_indx
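
The two routing hunks above combine into the dispatch sketched below: distributed routing is now strictly opt-in via `mode`, and `mode=None` (the new default) or a single-process launch falls back to plain top-k routing. This is a condensed view of the file's own function with bodies elided, not additional code from the commit; `_is_distributed_launch` and `topk` are the module's own helpers:

```python
def routing(x, logits, n_expts_act, sm_first=False, y_indx=None,
            EP=1, TP=1, expt_assignment=None, mode=None):
    if _is_distributed_launch() and mode:          # distributed path is opt-in
        if mode == "ep_sharding":
            if not expt_assignment:
                raise ValueError("expt_assignment must be provided for distributed routing.")
            ...  # EP-sharded routing path
        else:
            raise NotImplementedError(f"Distributed routing mode {mode} is not implemented yet.")
    else:
        # mode unset or single process: plain single-GPU top-k routing
        logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first)
        ...
```
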
@@ -262,6 +295,17 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
     xd = torch.randn((batch // world_size, dim1), device=dev).to(dtype_map[x_dtype])
     x0 = all_gather(xd, dim=0)
     expt_assignment = create_expt_assignment(EP, n_expts_tot, torch.device(dev))
+    symm_mem_pool.initialize_matmul_ogs(
+        n_tokens_global=batch,
+        d_input=dim1,
+        d_model=dim2,
+        n_expts_act=n_expts_act,
+        n_expts_tot=n_expts_tot,
+        n_ranks=world_size,
+        dtype=x0.dtype,
+        group=dist.group.WORLD,
+        device=torch.cuda.current_device(),
+    )
 
     # single-GPU pass
     def single(x):
@@ -279,7 +323,8 @@ def distributed(x):
         xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype)
         if n_expts_tot > 1: # sparse
             logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
-            x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment)
+            x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment,
+                                                 mode="ep_sharding")
         else: # dense
             x = all_gather(x, dim=0)
             rdata = gi = si = metadata = None
@@ -322,12 +367,12 @@ def distributed(x):
 )
 def test_mlp_mp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, monkeypatch):
     parallelism = TP * EP
+    if is_hip():
+        pytest.skip("[TODO] HIP support for distributed MoE.")
     if torch.cuda.device_count() < parallelism:
         pytest.skip(f"Test requires at least {parallelism} GPUs.")
     if is_cuda() and not cuda_capability_geq(9, 0):
         pytest.skip("Test requires CUDA compute capability >= 9.0.")
-    if is_hip() and get_cdna_version() == 4 and EP > 1:
-        pytest.skip("[TODO] Unknown issue with CDNA 4 and EP > 1")
     if TP > 1:
         pytest.skip("[TODO] TP > 1 is not supported yet in distributed mode.")

python/triton_kernels/tests/test_distributed.py

Lines changed: 22 additions & 100 deletions
@@ -4,10 +4,9 @@
 
 import torch
 import torch.distributed as dist
-import torch.distributed._symmetric_memory as symm_mem
 import torch.multiprocessing as mp
 import triton
-from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_dict_random, make_expt_assignment
+from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_dict_random, make_expt_assignment, symm_mem_pool
 from triton_kernels.reduce import reduce
 from triton_kernels.topk import topk
 from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx
@@ -166,99 +165,6 @@ def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, ex
     return z_dp_local
 
 
-def _capture_with_prepared_symm_mem(fn):
-    """
-    Run `fn` once to record symmetric-memory allocations, preallocate them outside the CUDA graph,
-    and capture a CUDA graph that reuses the recorded buffers.
-    """
-    orig_symm_empty = symm_mem.empty
-    orig_symm_rendezvous = symm_mem.rendezvous
-    recorded_empty_calls = []
-    recorded_rendezvous_calls = []
-    buffer_id_to_index = {}
-
-    def recording_empty(*args, **kwargs):
-        buf = orig_symm_empty(*args, **kwargs)
-        idx = len(recorded_empty_calls)
-        buffer_id_to_index[id(buf)] = idx
-        recorded_empty_calls.append((args, dict(kwargs)))
-        return buf
-
-    def recording_rendezvous(buf, *args, **kwargs):
-        buf_id = id(buf)
-        if buf_id not in buffer_id_to_index:
-            raise RuntimeError("symm_mem.rendezvous called on unknown buffer")
-        hdl = orig_symm_rendezvous(buf, *args, **kwargs)
-        recorded_rendezvous_calls.append((buffer_id_to_index[buf_id], args, dict(kwargs)))
-        return hdl
-
-    symm_mem.empty = recording_empty
-    symm_mem.rendezvous = recording_rendezvous
-    try:
-        warmup_result = fn()
-    finally:
-        symm_mem.empty = orig_symm_empty
-        symm_mem.rendezvous = orig_symm_rendezvous
-
-    prepared_empty_buffers = [orig_symm_empty(*args, **kwargs) for args, kwargs in recorded_empty_calls]
-    prepared_handles = [
-        orig_symm_rendezvous(prepared_empty_buffers[idx], *args, **kwargs)
-        for idx, args, kwargs in recorded_rendezvous_calls
-    ]
-
-    capture_stream = torch.cuda.Stream()
-    graph = torch.cuda.CUDAGraph()
-
-    if recorded_empty_calls:
-        empty_idx = 0
-        rendezvous_idx = 0
-
-        def reuse_empty(*args, **kwargs):
-            nonlocal empty_idx
-            if empty_idx >= len(prepared_empty_buffers):
-                raise RuntimeError("symm_mem.empty called more times than recorded")
-            expected_args, expected_kwargs = recorded_empty_calls[empty_idx]
-            if expected_args != args or expected_kwargs != kwargs:
-                raise RuntimeError("symm_mem.empty called with unexpected arguments")
-            buf = prepared_empty_buffers[empty_idx]
-            empty_idx += 1
-            return buf
-
-        def reuse_rendezvous(buf, *args, **kwargs):
-            nonlocal rendezvous_idx
-            if rendezvous_idx >= len(prepared_handles):
-                raise RuntimeError("symm_mem.rendezvous called more times than recorded")
-            expected_empty_idx, expected_args, expected_kwargs = recorded_rendezvous_calls[rendezvous_idx]
-            expected_buf = prepared_empty_buffers[expected_empty_idx]
-            if buf is not expected_buf:
-                raise RuntimeError("symm_mem.rendezvous received unexpected buffer")
-            if expected_args != args or expected_kwargs != kwargs:
-                raise RuntimeError("symm_mem.rendezvous called with unexpected arguments")
-            handle = prepared_handles[rendezvous_idx]
-            rendezvous_idx += 1
-            return handle
-
-        symm_mem.empty = reuse_empty
-        symm_mem.rendezvous = reuse_rendezvous
-        try:
-            with torch.cuda.stream(capture_stream):
-                with torch.cuda.graph(graph):
-                    fn()
-        finally:
-            symm_mem.empty = orig_symm_empty
-            symm_mem.rendezvous = orig_symm_rendezvous
-    else:
-        with torch.cuda.stream(capture_stream):
-            with torch.cuda.graph(graph):
-                fn()
-
-    # Keep references alive for as long as the graph exists.
-    graph._symm_mem_buffers = prepared_empty_buffers
-    graph._symm_mem_handles = prepared_handles
-    graph._capture_stream = capture_stream
-    return warmup_result, graph
-
-
 def _run_expert_sharding(rank, world_size, *, n_tokens, d_model, n_expts_tot, n_expts_act, affinity_mode):
     torch.manual_seed(0)
 
@@ -303,17 +209,33 @@ def run_mixture():
             y_indx=y_indx_global,
         )
 
-    # test cuda graph capture + replay with symmetric memory
-    y_dp_local_tri, graph = _capture_with_prepared_symm_mem(run_mixture)
+    symm_mem_pool.initialize_matmul_ogs(
+        n_tokens_global=n_tokens_global,
+        d_input=d_model,
+        d_model=d_model,
+        n_expts_act=n_expts_act,
+        n_expts_tot=n_expts_tot,
+        dtype=torch.bfloat16,
+        n_ranks=world_size,
+        group=dist.group.WORLD,
+        device=dev,
+    )
+    y_dp_local_tri = run_mixture()
     y_global_tri = torch.empty_like(y_global_ref)
 
     # Validate warmup run.
     dist.all_gather_into_tensor(y_global_tri, y_dp_local_tri)
     triton.testing.assert_close(y_global_ref, y_global_tri)
 
-    # Validate first replay with unchanged inputs.
-    graph.replay()
-    dist.all_gather_into_tensor(y_global_tri, y_dp_local_tri)
+    # Validate cuda graph capture + replay.
+    g = torch.cuda.CUDAGraph()
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        with torch.cuda.graph(g):
+            y_dp_local_tri_graph = run_mixture()
+
+    g.replay()
+    dist.all_gather_into_tensor(y_global_tri, y_dp_local_tri_graph)
     triton.testing.assert_close(y_global_ref, y_global_tri)
 
 
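Across the three files, the end state is the pattern the test now exercises: size the symmetric-memory pool up front, run the workload once eagerly, then capture and replay it with a plain `torch.cuda.CUDAGraph`, with no monkeypatching of `symm_mem.empty` or `symm_mem.rendezvous`. A generic helper along these lines illustrates the pattern; the helper and its name are assumptions for illustration, not part of the commit:

```python
import torch

def capture_after_warmup(fn, warmup_iters: int = 1):
    """Run `fn` eagerly, then capture a single call of it in a CUDA graph.

    Assumes every buffer `fn` needs (e.g. the pool sized by
    symm_mem_pool.initialize_matmul_ogs) was allocated beforehand, since no
    such allocation may happen during capture.
    """
    for _ in range(warmup_iters):
        eager_out = fn()                 # warms up kernels and autotuning
    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        with torch.cuda.graph(graph):
            graph_out = fn()             # output tensor is owned by the graph
    return graph, eager_out, graph_out

# Mirroring the test: graph.replay() refreshes graph_out in place, which can
# then be all-gathered and compared against the eager result.
```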