
Commit a1316c4

Fix BitmatrixMetadata col/row_sorted_indx (#8599)
`col_sorted_indx` and `row_sorted_indx` were passed to the `BitmatrixMetadata` constructor in the wrong order. The user side (`combine_indx` and `dispatch_indx`) also read the wrong field, so the two errors cancelled out. This PR fixes the constructor call to pass the fields in the right order and updates the user side accordingly.
1 parent dc4efec commit a1316c4
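
A minimal sketch of the failure mode (illustrative only; the toy `Metadata` class and index values below stand in for the real `BitmatrixMetadata`, which also carries `col_sum`): swapping positional arguments in the constructor silently swaps the fields, and a consumer that reads the swapped field still sees the intended data, which is how the bug stayed hidden. Constructing with keyword arguments, as this commit does, removes the ambiguity.

```python
# Illustrative sketch only: `Metadata` and the index values are hypothetical
# stand-ins for BitmatrixMetadata and its sorted-index tensors.
from dataclasses import dataclass


@dataclass
class Metadata:
    col_sorted_indx: list
    row_sorted_indx: list


col_sorted = [1, 6, 3, 7, 0, 2, 4, 5]  # hypothetical values
row_sorted = [4, 0, 5, 2, 6, 7, 1, 3]

# Bug: positional arguments swapped, so each field holds the other index array.
buggy = Metadata(row_sorted, col_sorted)
# A consumer that also reads the "wrong" field still gets the intended data,
# which is why the swap went unnoticed.
assert buggy.col_sorted_indx == row_sorted

# Fix: construct with keyword arguments and read the matching field.
fixed = Metadata(col_sorted_indx=col_sorted, row_sorted_indx=row_sorted)
assert fixed.col_sorted_indx == col_sorted
```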

4 files changed (+23, -15 lines)

python/triton_kernels/bench/distributed.py

Lines changed: 4 additions & 4 deletions
@@ -165,8 +165,8 @@ def routing(
         )
         active_indx = logits_global.indx
         expt_sizes = logits_global.mask_metadata.col_sum
-        dispatch_indx = logits_global.mask_metadata.col_sorted_indx
-        combine_indx = logits_global.mask_metadata.row_sorted_indx
+        dispatch_indx = logits_global.mask_metadata.row_sorted_indx
+        combine_indx = logits_global.mask_metadata.col_sorted_indx
         logits_global_metadata = make_ragged_tensor_metadata(expt_sizes, dispatch_indx.shape[0])
         x = convert_dp_to_ep(x, expt_assignment, active_indx, dispatch_indx)
         logits_local_metadata = remap_ragged_tensor_metadata(logits_global_metadata, expt_map)
@@ -184,8 +184,8 @@ def routing(
     else:
         # If mode is not specified or we have a single process, we do single-GPU routing.
         logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first)
-        dispatch_indx = logits.mask_metadata.col_sorted_indx
-        combine_indx = logits.mask_metadata.row_sorted_indx
+        dispatch_indx = logits.mask_metadata.row_sorted_indx
+        combine_indx = logits.mask_metadata.col_sorted_indx
         ragged_batch_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0])
         gate_scal = logits.vals.flatten()[combine_indx]
         routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act,

python/triton_kernels/tests/test_distributed.py

Lines changed: 4 additions & 4 deletions
@@ -120,8 +120,8 @@ def test_make_expt_assignment(n_expts_shard, n_expts_tot, affinity_mode):
 
 def routing(logits, n_expts_act, all_gather=False, y_indx=None):
     sparse_logits = topk(logits, n_expts_act, all_gather=all_gather, y_indx=y_indx)
-    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
-    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
+    dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.col_sorted_indx
     ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0])
     gate_scal = sparse_logits.vals.flatten()[combine_indx]
     routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, logits.shape[-1], n_expts_act,
@@ -146,8 +146,8 @@ def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, ex
     # expert histogram, dispatch/combine indx
     active_indx = l_global_active.indx
     expt_sizes = l_global_active.mask_metadata.col_sum
-    dispatch_indx = l_global_active.mask_metadata.col_sorted_indx
-    combine_indx = l_global_active.mask_metadata.row_sorted_indx
+    dispatch_indx = l_global_active.mask_metadata.row_sorted_indx
+    combine_indx = l_global_active.mask_metadata.col_sorted_indx
     # ragged tensor metadata
     x_global_metadata = make_ragged_tensor_metadata(expt_sizes, dispatch_indx.shape[0])
     # convert x from dp-local to expert-sorted, ep-local

python/triton_kernels/tests/test_matmul.py

Lines changed: 2 additions & 2 deletions
@@ -42,8 +42,8 @@ def alloc_rand_like(x):
 def init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, device="cuda"):
     logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True)
     sparse_logits = topk(logits, n_expts_act)
-    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
-    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
+    dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.col_sorted_indx
     ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0])
     routing_data = RoutingData(None, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act, ragged_batch_metadata)
     gather_idx = GatherIndx(combine_indx, dispatch_indx) if do_gather else None

python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py

Lines changed: 13 additions & 5 deletions
@@ -14,15 +14,15 @@ class BitmatrixMetadata:
     1 1 1 0 0 0 1
     0 0 1 0 1 0 0]
    `col_sum` = [1 2 3 0 2 2 1]
-   `row_sorted_indx` = cat([3 6 8], [1 9], [0 2 4 10], [5 7])
    `col_sorted_indx` = cat([5], [3 6], [0 7], [], [9 1 10], [2 4], [8])
+   `row_sorted_indx` = cat([3 6 8], [1 9], [0 2 4 10], [5 7])
    """
    # the number of entries equal to 1 in each column
    col_sum: torch.Tensor
-   # indices of nonzero values numbered col-major, grouped by rows, concatenated
-   row_sorted_indx: torch.Tensor
    # indices of nonzero values numbered row-major, grouped by cols, concatenated
    col_sorted_indx: torch.Tensor
+   # indices of nonzero values numbered col-major, grouped by rows, concatenated
+   row_sorted_indx: torch.Tensor
 
 
 # `make_bitmatrix_metadata`: entry point for optimized implementation
@@ -143,7 +143,11 @@ def make_bitmatrix_metadata(nonzero_indx, bitmatrix):
         col_offs, #
         TOKS_PER_ROW=toks_per_row, BLOCK_PER_TOK=PARTIAL_BLOCK_M, #
     )
-    return BitmatrixMetadata(col_sum, col_sorted_indx, row_sorted_indx)
+    return BitmatrixMetadata(
+        col_sum=col_sum,
+        col_sorted_indx=col_sorted_indx,
+        row_sorted_indx=row_sorted_indx,
+    )
 
 
 # `make_bitmatrix_metadata_torch`: entry point for reference implementation
@@ -157,4 +161,8 @@ def make_bitmatrix_metadata_torch(nonzero_indx, bitmatrix):
     col_sorted_indx = pad(torch.argsort(nonzero_indx[nonzero_indx != -1], stable=True), nonzero_indx.numel())
     row_sorted_indx = pad(torch.argsort(col_sorted_indx[col_sorted_indx != -1], stable=True), nonzero_indx.numel())
     col_sum = torch.histc(nonzero_indx, bins=n_batches, max=n_batches - 1).int()
-    return BitmatrixMetadata(col_sum, col_sorted_indx, row_sorted_indx)
+    return BitmatrixMetadata(
+        col_sum=col_sum,
+        col_sorted_indx=col_sorted_indx,
+        row_sorted_indx=row_sorted_indx,
+    )
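
For intuition about the two orderings documented above, here is a small standalone sketch in plain PyTorch (not the library's kernels); the 8-element `nonzero_indx` example is made up, and `torch.bincount` stands in for the `histc` call used by the reference implementation:

```python
import torch

# Hypothetical bitmatrix with 4 rows and 4 columns, two nonzeros per row.
# `nonzero_indx` lists, in row-major order, the column index of each nonzero
# entry (no -1 padding), mirroring the input of make_bitmatrix_metadata_torch.
nonzero_indx = torch.tensor([2, 3, 0, 2, 1, 2, 0, 1])
n_cols = 4

col_sum = torch.bincount(nonzero_indx, minlength=n_cols)       # nonzeros per column
col_sorted_indx = torch.argsort(nonzero_indx, stable=True)     # row-major numbers, grouped by column
row_sorted_indx = torch.argsort(col_sorted_indx, stable=True)  # col-major numbers, grouped by row

print(col_sum)          # tensor([2, 2, 3, 1])
print(col_sorted_indx)  # tensor([2, 6, 4, 7, 0, 3, 5, 1])
print(row_sorted_indx)  # tensor([4, 7, 0, 5, 2, 6, 1, 3])

# The two orderings are mutually inverse permutations: gathering by one and
# then by the other restores the original row-major order.
assert torch.equal(row_sorted_indx[col_sorted_indx], torch.arange(8))
```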
