Skip to content

Commit 2509898

Browse files
authored
[KERNELS] support 32-bit inputs in topk.py (#6856)
1 parent e64cda2 commit 2509898

File tree

4 files changed

+22
-16
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _query_gpu_specs():
3838

3939
gpu_specs = {
4040
"NVIDIA H100 80GB HBM3": {"MAX_TFLOPS8": 1979, "MAX_TFLOPS16": 989, "MAX_TBPS": 3.35},
41-
"HGX GB200": {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0},
41+
"NVIDIA GB200": {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0},
4242
"AMD Instinct MI300X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 5.3},
4343
"AMD Instinct MI325X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 6.0},
4444
}
@@ -219,10 +219,11 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
219219
has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
220220
if SPECS is None:
221221
print("Current GPU has no specs provided, utilization is N/A")
222-
batch_ranges = [(1024, 32768, 1024)]
222+
batch_ranges_dense = [(1024, 32768, 1024)]
223+
batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
223224
dense_dtypes = ["fp8", "fp8"]
224225
quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
225-
roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
226-
roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
227-
roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
228-
roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
226+
roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
227+
roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
228+
roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
229+
roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")

python/triton_kernels/triton_kernels/routing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
6262
HIST_BLOCK_M = 64
6363
INDX_OFFS_BLOCK_M = 512
6464
MEMSET_BLOCK = 1024
65-
assert logits.dtype.itemsize == 2
6665
n_tokens, n_expts_tot = logits.shape
6766
n_gates = n_tokens * n_expts_act
6867
device = logits.device

python/triton_kernels/triton_kernels/topk.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ def topk(x, k, dim=1, return_bitmatrix=True):
77
cdiv = lambda a, b: (a + b - 1) // b
88
BLOCK_M = 8
99
BLOCK_N = 128
10-
assert x.dtype.itemsize == 2
1110
assert x.ndim == 2
1211
assert x.shape[-1] < 32768
1312
assert dim == 1

python/triton_kernels/triton_kernels/topk_details/_topk.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
@triton.jit
66
def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
77
BLOCK_N: tl.constexpr):
8+
x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
9+
x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
10+
x_ultype: tl.constexpr = tl.dtype(f"uint{2*x_nbits}")
11+
x_dbtype: tl.constexpr = tl.dtype(f"fp{2*x_nbits}")
812

913
# subtract 1 from loop iterations because we peel the first (masked) iteration:
1014
loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
@@ -15,8 +19,8 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
1519
# first iteration:
1620
X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
1721
x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
18-
x = (x.to(tl.uint16, bitcast=True).to(tl.int32) << 16) | offs_x_n[None, :]
19-
x = x.to(tl.float32, bitcast=True)
22+
x = (x.to(x_utype, bitcast=True).to(x_ultype) << x_nbits) | offs_x_n[None, :]
23+
x = x.to(x_dbtype, bitcast=True)
2024

2125
acc = tl.topk(x, N_EXPTS_ACT, dim=1)
2226

@@ -26,8 +30,8 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
2630
X_ptrs -= BLOCK_N
2731
offs_x_n -= BLOCK_N
2832
x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
29-
x = (x.to(tl.uint16, bitcast=True).to(tl.int32) << 16) | offs_x_n[None, :]
30-
x = x.to(tl.float32, bitcast=True)
33+
x = (x.to(x_utype, bitcast=True).to(x_ultype) << x_nbits) | offs_x_n[None, :]
34+
x = x.to(x_dbtype, bitcast=True)
3135
acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))
3236

3337
return acc
@@ -43,18 +47,21 @@ def _topk(X, stride_xm, # inputs
4347
tl.static_assert(BLOCK_N % 32 == 0)
4448
tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
4549
x_dtype: tl.constexpr = X.dtype.element_ty
50+
x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
51+
x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
52+
x_ultype: tl.constexpr = tl.dtype(f"uint{2*x_nbits}")
4653

4754
# load logits
4855
offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
4956
mask_m = offs_m[:, None] < n_rows
5057
y = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N)
51-
y = y.to(tl.uint32, bitcast=True)
58+
y = y.to(x_ultype, bitcast=True)
5259

5360
# sort result in direction of ascending expert index
54-
y = (y << 16) | (y >> 16)
61+
y = (y << x_nbits) | (y >> x_nbits)
5562
y = tl.sort(y, dim=1)
56-
y_indices = y >> 16
57-
y_values = (y & 0x0000FFFF).to(tl.uint16).to(x_dtype, bitcast=True)
63+
y_indices = y >> x_nbits
64+
y_values = (y & ((1 << x_nbits) - 1)).to(x_utype).to(x_dtype, bitcast=True)
5865
y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)
5966

6067
# write back

0 commit comments

Comments (0)