
Commit bea27e3

[triton_kernels] some clean-up of the routing (#8330)
1 parent 9117748 commit bea27e3

8 files changed: +177 −225 lines changed

python/triton_kernels/bench/distributed.py

Lines changed: 41 additions & 3 deletions
@@ -18,12 +18,10 @@
     ScatterIndx,
     compute_expt_data_torch,
     topk_torch,
-    prune_routing,
     routing_from_bitmatrix,
 )
 from triton_kernels.topk import topk
 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.routing_details._routing_compute import _routing_clear_bitmatrix
 from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq
 from triton_kernels.tensor_details import layout
 from triton_kernels.tensor import Bitmatrix
@@ -291,6 +289,46 @@ def pack_bitmatrix(
     tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)


+@triton.jit
+def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr):
+    pid_m = tl.program_id(0)
+    cutoff_word = cutoff // 32
+    cutoff_bit = cutoff % 32
+    cutoff_mask = (1 << (cutoff_bit)) - 1
+    for start_n in range(0, shape_bn, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
+        values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn)
+        values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)
+        values = tl.where(offs_n > cutoff_word, 0, values)
+        tl.store(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, values, mask=offs_n < shape_bn)
+
+
+class PruneRouting(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
+        from triton_kernels.compaction import compaction
+        n_tokens_pad = expt_scal.shape[0]
+        assert n_expts_tot % simulated_ep == 0
+        _routing_clear_bitmatrix[(n_tokens_pad, )](
+            bitmatrix.storage.data,
+            bitmatrix.storage.data.stride(0),
+            bitmatrix.storage.data.stride(1),
+            bitmatrix.storage.data.shape[1],
+            n_expts_tot // simulated_ep,
+            BLOCK_N=512,
+        )
+        # perform compaction to update expt_scal / expt_indx
+        expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix)
+        n_expts_tot = n_expts_tot // simulated_ep
+        bitmatrix.shape[-1] = n_expts_tot
+        return expt_scal, expt_indx, bitmatrix
+
+
+def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
+    return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep)
+
+
 def routing_triton(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None, EP=1, TP=1):
     _, n_expts_tot = logits.shape

@@ -354,7 +392,7 @@ def routing(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None,
         else:
             raise ValueError(f"Unknown backend: {backend}")
     else:
-        return x, *triton_kernels.routing.routing(logits, n_expts_act, sm_first, expt_indx, EP, n_rows), None
+        return x, *triton_kernels.routing.routing(logits, n_expts_act, sm_first, expt_indx, n_rows), None


 # The following dummy methods simulate the behavior of distributed operations
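
Note on the hunk above: the benchmark now carries its own copy of _routing_clear_bitmatrix plus a PruneRouting autograd wrapper instead of importing them (see the removed import lines). The kernel zeroes every expert bit at index >= cutoff so that prune_routing can simulate expert parallelism by keeping only the first n_expts_tot // simulated_ep experts before compaction. A minimal pure-PyTorch sketch of the same masking, assuming an unpacked (n_tokens, n_experts) boolean matrix rather than the packed 32-bit Bitmatrix storage the kernel actually touches:

import torch

def clear_bitmatrix_reference(bits: torch.Tensor, cutoff: int) -> torch.Tensor:
    # bits: (n_tokens, n_experts) bool mask of selected experts (hypothetical dense layout).
    # Keep only experts with index < cutoff; the Triton kernel does the equivalent on
    # packed 32-bit words via cutoff_word / cutoff_bit / cutoff_mask.
    keep = torch.arange(bits.shape[1], device=bits.device) < cutoff
    return bits & keep

# With n_expts_tot=8 and simulated_ep=4, cutoff = 8 // 4 = 2, so only experts 0 and 1
# survive; compaction then rewrites expt_scal / expt_indx to match the pruned bitmatrix.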

python/triton_kernels/tests/conftest.py

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,6 @@
 import pytest
 import tempfile
+import os


 def pytest_addoption(parser):
@@ -29,3 +30,11 @@ def fresh_triton_cache():
     with knobs.cache.scope(), knobs.runtime.scope():
         knobs.cache.dir = tmpdir
         yield tmpdir
+
+
+def pytest_configure(config):
+    worker_id = os.environ.get("PYTEST_XDIST_WORKER")
+    if worker_id is not None and worker_id.startswith("gw"):
+        import torch
+        gpu_id = int(worker_id[2:])  # map gw0 → 0, gw1 → 1, ...
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id % torch.cuda.device_count())
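
Note on the conftest change: pytest_configure runs in every pytest-xdist worker process (the guard skips the controller, where PYTEST_XDIST_WORKER is unset), and it pins each worker to a single GPU by round-robin over torch.cuda.device_count(). A small illustration of the resulting assignment, assuming four workers on a 2-GPU machine (hypothetical numbers, not part of the commit):

# Hypothetical walk-through of the round-robin mapping with 2 visible GPUs.
num_gpus = 2
for worker_id in ["gw0", "gw1", "gw2", "gw3"]:
    gpu_id = int(worker_id[2:]) % num_gpus
    print(f"{worker_id} -> CUDA_VISIBLE_DEVICES={gpu_id}")
# gw0 -> 0, gw1 -> 1, gw2 -> 0, gw3 -> 1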

python/triton_kernels/tests/test_matmul.py

Lines changed: 19 additions & 45 deletions
@@ -45,19 +45,16 @@ def mask_indx(idx, n_expts_act):
     return idx


-def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"):
+def init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, device="cuda"):
     logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True)
-    routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards)
+    routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act)
     routing_data.gate_scal = None
     gather_idx = gather_idx if do_gather else None
     scatter_idx = scatter_idx if do_scatter else None
-    # TODO: re-enable
-    # if do_gather and do_scatter and n_expts_act == 1 and n_expt_shards == 1:
-    #     scatter_idx = mask_indx(scatter_idx, n_expts_act)
     return m, routing_data, gather_idx, scatter_idx


-def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
+def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, act_dtype, weight_dtype,
                       has_y_gammas, requires_grad=True, device="cuda",
                       inner_expt_opt=None, padding_block_k=None):
     torch.manual_seed(0)
@@ -70,7 +67,7 @@ def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, n_
     else:
         in_m = m * (n_expts_act if gindx is None else 1)
     shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k)
-    shape_batch = tuple() if (mode == "plain" or inner_expt_opt is not None) else (n_expts_tot // n_expt_shards, )
+    shape_batch = tuple() if (mode == "plain" or inner_expt_opt is not None) else (n_expts_tot, )
     x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad)
     w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
     bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
@@ -194,7 +191,6 @@ class Case:
     weight_dtype_str: str
     n_expts_tot: int = 1
     n_expts_act: int = 1
-    n_expt_shards: int = 1
     split_k: int = 1
     hbm_swizzling: bool = False
     epilogue_subtile: Union[int, None] = None
@@ -216,10 +212,6 @@ class Case:
     Case(5, 7, 0, "batched", "float16", "float16"),
     # Non-mx types:
     Case(16, 256, 256, "ragged", "float16", "float16", 128, 4),
-    Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2),
-    Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4),
-    Case(400, 300, 500, "ragged", "float16", "float16", 32, 4, n_expt_shards=4),
-    Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2),
     Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
     Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
     Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1),
@@ -235,8 +227,6 @@ class Case:
     Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2),
     Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4),
     Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2),
-    Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2),
-    Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2),
     Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2),
     Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1),
     Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),
@@ -291,19 +281,17 @@ class Case:
     Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
     Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
     Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2),
-    Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2),
     Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2),
     Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"),
     Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
     Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
-    Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
 ] + [
-    Case(320, 400, 400, mode, dtype, dtype, n_expts_tot, n_expts_act, n_expt_shards=n_expt_shards,
+    Case(320, 400, 400, mode, dtype, dtype, n_expts_tot, n_expts_act,
          x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose)
-    for (mode, n_expts_tot, n_expts_act, n_expt_shards) in (
-        ("batched", 1, 1, 1),
-        ("ragged", 8, 4, 1),
-        ("ragged", 32, 4, 4),
+    for (mode, n_expts_tot, n_expts_act) in (
+        ("batched", 1, 1),
+        ("ragged", 8, 4),
+        ("ragged", 32, 4),
     )
     for dtype in ("float16", "float8_e5m2")
     for x_transpose in (False, True)
@@ -326,7 +314,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope, fresh_knobs):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -424,17 +412,17 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     weight_dtype = dtype_str_to_torch(weight_dtype_str)
     act_dtype = dtype_str_to_torch(act_dtype_str)
     precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp,
-                                   n_expts_tot // n_expt_shards, expt_is_inner, device=device)
+                                   n_expts_tot, expt_is_inner, device=device)
     # precision_opt.x_pad_trans_requires_flexpoint = False
     if mode == "ragged":
-        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
+        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter,
                                                    device=device)
     else:
         rdata = gindx = sindx = None

     padding_block_k = 32
     x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act,
-                                                                 n_expt_shards, mode, torch.bfloat16 if act_mxfp8 else act_dtype, #
+                                                                 mode, torch.bfloat16 if act_mxfp8 else act_dtype, #
                                                                  torch.bfloat16 if weight_mxfp else weight_dtype,
                                                                  has_y_gammas, requires_grad=test_bwd, device=device,
                                                                  inner_expt_opt=inner_expt_opt, padding_block_k=padding_block_k)
@@ -446,9 +434,9 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_tri = w_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd)
     if y_transpose:
         if mode == "batched":
-            yT_shape = (n_expts_tot // n_expt_shards, n, x_tri.shape[-2])
+            yT_shape = (n_expts_tot, n, x_tri.shape[-2])
         elif expt_is_inner:
-            yT_shape = (n_expts_tot // n_expt_shards, n, k)
+            yT_shape = (n_expts_tot, n, k)
         elif sindx is not None:
             yT_shape = (n, m)
         else:
549537
assert val.ndim == 3
550538
return val / scal[:, None, None]
551539

552-
if n_expt_shards > 1:
553-
if do_scatter:
554-
indx = sindx.dst_indx[sindx.dst_indx != -1]
555-
ref_y = ref_y[indx // n_expts_act, :]
556-
if act_is_float8:
557-
tri_y = tri_y.view(torch.int8)
558-
tri_y = tri_y[indx // n_expts_act, :]
559-
if act_is_float8:
560-
tri_y = tri_y.view(act_dtype)
561-
elif not expt_is_inner:
562-
n_rows = rdata.expt_hist.sum()
563-
assert n_rows > 0
564-
ref_y = ref_y[:n_rows]
565-
tri_y = tri_y[:n_rows]
566540
if act_mxfp8:
567541
tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, target_dtype=torch.bfloat16, axis=-1).to(ref_y.dtype)
568542
ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1)
@@ -683,18 +657,18 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
683657
"split_k": split_k,
684658
"fused_scatter": fused_scatter,
685659
}
686-
n_expts_tot, n_expts_act, n_expt_shards = 1, 1, 1
660+
n_expts_tot, n_expts_act = 1, 1
687661
opt_flags.update_opt_flags_constraints(constraints)
688662

689663
weight_dtype, act_dtype = torch.float16, torch.float16
690664
if mode == "ragged":
691-
m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
665+
m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter,
692666
device=device)
693667
else:
694668
rdata = gindx = sindx = None
695669

696-
precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot // n_expt_shards, device=device)
697-
x, w, bias, _, _ = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
670+
precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot, device=device)
671+
x, w, bias, _, _ = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode,
698672
act_dtype, weight_dtype, False, requires_grad=False, device=device)
699673

700674
if mode == "batched":

python/triton_kernels/tests/test_routing.py

Lines changed: 2 additions & 5 deletions
@@ -55,11 +55,8 @@ def _assert_indx_equal(ref, tri):
     tri_expt_data = tri_routing_data.expt_data
     assert_equal(ref_expt_data.hist, tri_expt_data.hist)
     assert_equal(ref_expt_data.token_offs_raw, tri_expt_data.token_offs_raw)
-    assert len(ref_expt_data.token_offs_pad) == len(tri_expt_data.token_offs_pad)
-    assert len(ref_expt_data.block_pid_map) == len(tri_expt_data.block_pid_map)
-    for block_m in ref_expt_data.token_offs_pad.keys():
-        assert_equal(ref_expt_data.token_offs_pad[block_m], tri_expt_data.token_offs_pad[block_m])
-        assert_equal(ref_expt_data.block_pid_map[block_m], tri_expt_data.block_pid_map[block_m])
+    assert_equal(ref_expt_data.token_offs_pad_data, tri_expt_data.token_offs_pad_data)
+    assert_equal(ref_expt_data.block_pid_map_data, tri_expt_data.block_pid_map_data)

     assert ref_routing_data.n_expts_tot == ref_routing_data.n_expts_tot
     assert ref_routing_data.n_expts_act == ref_routing_data.n_expts_act

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 2 deletions
@@ -169,8 +169,8 @@ def make_kernel_args(data, block_m):
     return (
         expt_data.hist,
         expt_data.token_offs_raw,
-        expt_data.token_offs_pad[block],
-        expt_data.block_pid_map[block],
+        expt_data.token_offs_pad(block),
+        expt_data.block_pid_map(block),
     ) + args

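Side note on the last two hunks (matmul_ogs.py together with the test_routing.py change): the call sites switch from token_offs_pad[block] / block_pid_map[block] to token_offs_pad(block) / block_pid_map(block), and the test now compares token_offs_pad_data / block_pid_map_data tensors directly, which suggests the per-block tables became callables backed by a single stacked tensor. A purely illustrative sketch of that pattern with hypothetical names (not the library's actual expt_data implementation):

import torch

class PerBlockTable:
    # Hypothetical container: one stacked tensor with a row per supported block size,
    # exposed through a call that slices out the row for a given block_m.
    def __init__(self, data: torch.Tensor, block_sizes: list):
        self.data = data                        # shape (len(block_sizes), n)
        self._row = {b: i for i, b in enumerate(block_sizes)}

    def __call__(self, block_m: int) -> torch.Tensor:
        return self.data[self._row[block_m]]

# Usage mirroring the new call sites: table(block) instead of table[block],
# while tests can assert on table.data directly.
table = PerBlockTable(torch.arange(8).reshape(2, 4), block_sizes=[16, 32])
row_16 = table(16)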

0 commit comments
