
Commit 273649e

[KERNELS] Support y_indx and uniform distribution (#8472)
1 parent 28533b1 commit 273649e

File tree: 3 files changed (+37, -10 lines)


python/triton_kernels/tests/test_distributed.py

Lines changed: 34 additions & 7 deletions
@@ -27,6 +27,23 @@ def _make_expt_dict_for_mode(n_shards, n_expts_tot, affinity_mode):
         raise ValueError(f"Unknown affinity mode: {affinity_mode}") from exc
 
 
+def _make_y_indx_for_mode(n_tokens_global, n_expts_tot, n_expts_act, n_shards, affinity_mode, dev):
+    y_indx_global = None
+    if affinity_mode == "uniform":
+        if n_expts_tot % n_shards != 0:
+            raise ValueError("uniform affinity requires experts evenly divisible by shards")
+        expts_per_rank = n_expts_tot // n_shards
+        rounds = (n_expts_act + n_shards - 1) // n_shards
+        if rounds > expts_per_rank:
+            raise ValueError("round-robin selection exceeds experts available per shard")
+        order = torch.arange(n_expts_act, device=dev, dtype=torch.int32)
+        shard_order = order % n_shards
+        intra_shard = order // n_shards
+        round_robin_indx = (shard_order * expts_per_rank + intra_shard).to(torch.int16)
+        y_indx_global = round_robin_indx.unsqueeze(0).expand(n_tokens_global, -1).contiguous()
+    return y_indx_global
+
+
 # ------------------------------------------------------------
 # fixture
 # ------------------------------------------------------------
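For orientation, here is a minimal standalone sketch of the index pattern the new _make_y_indx_for_mode helper produces under "uniform" affinity. The sizes below are made up for illustration and are not taken from the test:

import torch

# Hypothetical sizes, chosen only to make the pattern visible.
n_tokens_global, n_expts_tot, n_expts_act, n_shards = 2, 8, 4, 4
dev = "cpu"

expts_per_rank = n_expts_tot // n_shards                  # 2 experts owned by each shard
order = torch.arange(n_expts_act, device=dev, dtype=torch.int32)
shard_order = order % n_shards                            # shard serving the i-th selected slot
intra_shard = order // n_shards                           # local expert index on that shard
round_robin_indx = (shard_order * expts_per_rank + intra_shard).to(torch.int16)
print(round_robin_indx)                                   # tensor([0, 2, 4, 6], dtype=torch.int16)

# Every token is forced onto the same round-robin selection.
y_indx_global = round_robin_indx.unsqueeze(0).expand(n_tokens_global, -1).contiguous()
print(y_indx_global.shape)                                # torch.Size([2, 4])

The selection picks one expert from each shard in turn, which is what spreads the forced routing evenly across ranks.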
@@ -102,8 +119,8 @@ def test_make_expt_assignment(n_expts_shard, n_expts_tot, affinity_mode):
 # ------------------------------------------------------------
 
 
-def routing(logits, n_expts_act, all_gather=False):
-    sparse_logits = topk(logits, n_expts_act, all_gather=all_gather)
+def routing(logits, n_expts_act, all_gather=False, y_indx=None):
+    sparse_logits = topk(logits, n_expts_act, all_gather=all_gather, y_indx=y_indx)
     dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
     combine_indx = sparse_logits.mask_metadata.row_sorted_indx
     ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0])
@@ -115,17 +132,18 @@ def routing(logits, n_expts_act, all_gather=False):
     return routing_data, gather_idx, scatter_idx, sparse_logits.indx
 
 
-def mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act):
-    rdata, combine_indx, dispatch_indx, _ = routing(l_global, n_expts_act)
+def mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act, y_indx=None):
+    rdata, combine_indx, dispatch_indx, _ = routing(l_global, n_expts_act, y_indx=y_indx)
     y_global = matmul_ogs(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx)
     return y_global
 
 
-def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, expt_assignment, n_expts_act):
+def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, expt_assignment, n_expts_act,
+                              y_indx=None):
     rank = dist.get_rank()
     expt_map = expt_assignment.expt_map[rank, :]
     # active global logits (sparse)
-    l_global_active = topk(l_dp_local, n_expts_act, apply_softmax=True, all_gather=True)
+    l_global_active = topk(l_dp_local, n_expts_act, apply_softmax=True, all_gather=True, y_indx=y_indx)
     # expert histogram, dispatch/combine indx
     active_indx = l_global_active.indx
     expt_sizes = l_global_active.mask_metadata.col_sum
@@ -264,7 +282,15 @@ def _run_expert_sharding(rank, world_size, *, n_tokens, d_model, n_expts_tot, n_
     l_dp_local = l_global[first_token_indx:last_token_indx, :]
     # routing
     # test correctness
-    y_global_ref = mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act)
+    y_indx_global = _make_y_indx_for_mode(n_tokens_global, n_expts_tot, n_expts_act, n_shards, affinity_mode, dev)
+    y_global_ref = mixture_of_expt_nosharded(
+        x_global,
+        l_global,
+        w_global,
+        b_global,
+        n_expts_act,
+        y_indx=y_indx_global,
+    )
 
     def run_mixture():
         return mixture_of_expt_epsharded(
@@ -274,6 +300,7 @@ def run_mixture():
             b_ep_local,
             expt_assignment,
             n_expts_act,
+            y_indx=y_indx_global,
         )
 
     # test cuda graph capture + replay with symmetric memory
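A quick, self-contained way to see why this selection is called "uniform": mapping each forced expert index back to the rank that owns it shows every shard receiving the same share of routed work whenever n_shards divides n_expts_act. This is only an illustrative check with made-up sizes, not part of the commit:

import torch

n_tokens_global, n_expts_tot, n_expts_act, n_shards = 32, 8, 4, 4
expts_per_rank = n_expts_tot // n_shards

order = torch.arange(n_expts_act, dtype=torch.int64)
round_robin = (order % n_shards) * expts_per_rank + order // n_shards
y_indx_global = round_robin.unsqueeze(0).expand(n_tokens_global, -1)

# Recover the owning rank of every forced (token, expert) assignment and count them.
owning_rank = (y_indx_global // expts_per_rank).flatten()
per_rank = torch.bincount(owning_rank, minlength=n_shards)
print(per_rank)   # tensor([32, 32, 32, 32]) -> identical load on every shard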

python/triton_kernels/triton_kernels/topk.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,6 @@ def topk_forward(x, k, apply_softmax=True, dim=1, y_indx=None, n_rows=None, all_
     assert len(x.shape) == 2
     assert x.shape_max[-1] < 32768
     assert dim == 1
-    assert not all_gather or not use_provided_indx
     n_rows, n_cols = x.shape
     n_rows_max, _ = x.shape_max
     dev = x.device
@@ -62,7 +61,8 @@ def topk_forward(x, k, apply_softmax=True, dim=1, y_indx=None, n_rows=None, all_
     )
     if all_gather:
         y_vals_hdl.barrier(channel=0)
-        y_indx_hdl.barrier(channel=0)
+        if y_indx_hdl is not None:
+            y_indx_hdl.barrier(channel=0)
         bitmatrix_hdl.barrier(channel=0)
     bitmatrix_shape = [n_rows * dist.get_world_size() if all_gather else n_rows, n_cols]
     bitmatrix_shape_max = [n_rows_out_max, None]
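The deleted assertion used to forbid combining all_gather with a caller-provided index; now that the combination is allowed, the barrier on the index handle is guarded, since y_indx_hdl can apparently be None when the caller supplies the indices (an inference from the guard above, not stated in the diff). A tiny sketch of that control flow, with stand-in handles rather than the real symmetric-memory handles created inside topk_forward:

class _StubHandle:
    """Stand-in for a symmetric-memory handle; only mimics the barrier call."""
    def __init__(self, name):
        self.name = name
    def barrier(self, channel=0):
        print(f"barrier({self.name}, channel={channel})")

def _sync_after_all_gather(y_vals_hdl, y_indx_hdl, bitmatrix_hdl):
    y_vals_hdl.barrier(channel=0)
    if y_indx_hdl is not None:          # skipped when the caller supplied y_indx
        y_indx_hdl.barrier(channel=0)
    bitmatrix_hdl.barrier(channel=0)

# With a provided y_indx there is no index handle to wait on.
_sync_after_all_gather(_StubHandle("y_vals"), None, _StubHandle("bitmatrix"))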

python/triton_kernels/triton_kernels/topk_details/_topk_forward.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def _topk_forward(X, stride_xm, # inputs
     mask_m = offs_m[:, None] < n_rows
     if USE_PROVIDED_INDX:
         tl.static_assert(len(PeerYis) == 1)
-        Yi_ptrs = PeerYis[0] + offs_m[:, None] * stride_ym + offs_y_n[None, :]
+        Yi_ptrs = PeerYis[0] + (dst_offs_m + offs_m[:, None]) * stride_ym + offs_y_n[None, :]
         y_indices = tl.load(Yi_ptrs, mask=mask_m)
         Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
         y_values = tl.load(Xv_ptrs, mask=mask_m)
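The kernel fix adds the rank's destination row offset (dst_offs_m) when addressing the provided index buffer. A rough NumPy illustration of why that matters when the buffer spans all ranks' rows; the rank * n_rows offset formula here is an assumption made for the sake of the example:

import numpy as np

world_size, n_rows, k = 2, 3, 2
# A global index buffer whose rows differ per rank's slice, for illustration.
y_indx_global = np.concatenate(
    [np.full((n_rows, k), fill_value=r, dtype=np.int16) for r in range(world_size)]
)

rank = 1
offs_m = np.arange(n_rows)
dst_offs_m = rank * n_rows                         # assumed per-rank row offset into the gathered buffer

rows_old = y_indx_global[offs_m]                   # old addressing: every rank reads rows [0, n_rows)
rows_new = y_indx_global[dst_offs_m + offs_m]      # fixed addressing: each rank reads its own slice
print(rows_old[0, 0], rows_new[0, 0])              # 0 1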
