
Commit e120c12

Merge commit '981e987eed9053b952f81153bc0779c99d8c642e'
2 parents 54214d5 + 981e987

File tree

7 files changed: +201 additions, -213 deletions

bench/tests/test_routing.py

Lines changed: 5 additions & 0 deletions
@@ -83,6 +83,11 @@ def bench_routing():
     tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act)
     tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m)
     proton.finalize()
+    try:
+        import os
+        os.system("proton-viewer -m time/ms routing.hatchet")
+    except:
+        pass


 if __name__ == "__main__":
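The added block shells out to proton-viewer and silences any failure with a bare except. As a rough equivalent (not part of the commit, and assuming proton-viewer is on PATH next to the freshly written routing.hatchet), the same guard can be written explicitly:

import shutil
import subprocess

# Sketch only: print the time/ms metric for the routing profile if the viewer exists.
if shutil.which("proton-viewer") is not None:
    subprocess.run(["proton-viewer", "-m", "time/ms", "routing.hatchet"], check=False)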

bench/triton_bench/routing.py

Lines changed: 35 additions & 60 deletions
@@ -13,10 +13,9 @@ def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, # histog
         offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
         mask_n = offs_n < hist_size
         hist2 = tl.load(ExpertHist + offs_n, mask=mask_n)
-        tok_starts = tl.cumsum(hist2, 0) + x
+        tok_starts = tl.cumsum(hist2, 0) - hist2 + x
         x += tl.sum(hist2, 0)
-        tl.store(FinalExpertOffs, 0)
-        tl.store(FinalExpertOffs + 1 + offs_n, tok_starts, mask=mask_n)
+        tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
         offs_n += BLOCK_N

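In the hunk above, subtracting hist2 from the inclusive tl.cumsum yields an exclusive prefix sum, so each expert's starting offset lands directly at FinalExpertOffs + expert and the separate store of the leading zero (plus the +1 shift) disappears. A minimal NumPy stand-in for that identity (illustrative only, not from the repository):

import numpy as np

hist = np.array([3, 0, 2, 5])              # tokens routed to each expert
inclusive = np.cumsum(hist)                # [3, 3, 5, 10]
exclusive = inclusive - hist               # [0, 3, 3, 5]: first-token offset per expert
assert (exclusive == np.concatenate(([0], inclusive[:-1]))).all()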

@@ -52,51 +51,33 @@ def _keyed_add(x, y):


 @triton.jit
-def _count_previous(x):
-    """
-    Input x : uint16[..., N]
-    Output y : uint32[..., N]
-    semantics : y[..., i] = sum_j((x[..., j] == x[..., i]) & (j < i))
-    credits: @apgoucher
-    """
+def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, n_gates,
+                          BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):

-    BLOCK_N: tl.constexpr = x.shape[-1]  # summation axis
-    BATCHES: tl.constexpr = x.numel // BLOCK_N  # number of batches
+    pid_m = tl.program_id(0)

-    # reduce to two-dimensional case:
-    y = tl.reshape(x, [BATCHES, BLOCK_N]).to(tl.uint32)
+    tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)

-    tl.static_assert(BLOCK_N <= 32768, "compute_run_lengths requires axis to have length <= 32768")
+    local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
+    offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
+    expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)

-    # sort (expert, position) ordered pairs to perform an argsort:
-    kv_pairs = ((y << 16) | tl.arange(0, BLOCK_N)[None, :]).to(tl.uint32)
-    sorted_kv_pairs = tl.sort(kv_pairs, 1)
+    # stable-sort by expert ID:
+    kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
+    kv_pairs = tl.sort(kv_pairs, 0)
+    expert = kv_pairs >> 16
+    offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xffff)
+    mask = expert != 0xffff
+    gate_scal = tl.load(ExptScal + offs, mask=mask)

     # compute run lengths in expert-sorted order:
-    x = (sorted_kv_pairs & 0xffff0000 | 0x00000001)
-    expts_and_inclusive_run_lengths = tl.associative_scan(x, 1, _keyed_add)
+    x = (kv_pairs & 0xffff0000 | 0x00000001)
+    expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
     exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff

-    # undo permutation by doing another sort
-    # TODO rewrite this when tl.scatter becomes available
-    kv_pairs = ((sorted_kv_pairs << 16) | exclusive_run_lengths).to(tl.uint32)
-    unsorted_run_lengths = tl.sort(kv_pairs) & 0xffff
-
-    res = tl.reshape(unsorted_run_lengths, x.shape)
-    return res
-
+    gates = tl.load(PartialOffs + pid_m * stride_pm + expert, mask=(expert != 0xffff))
+    gates += exclusive_run_lengths

-@triton.jit
-def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, n_gates,
-                          BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
-    pid_m = tl.program_id(0)
-    offs = pid_m * BLOCK_M * N_EXPTS_ACT + tl.arange(0, N_EXPTS_ACT * BLOCK_M)
-    mask = offs < n_gates
-    indx = tl.load(ExptIndx + offs, mask=mask)
-    mask = mask & (indx != -1)
-    gates = tl.load(PartialOffs + pid_m * stride_pm + indx, mask=mask)
-    gates += tl.reshape(_count_previous(indx), [BLOCK_M * N_EXPTS_ACT])
-    gate_scal = tl.load(ExptScal + offs, mask=mask)
     tl.store(ScatterIndx + offs, gates, mask=mask)
     tl.store(GatherIndx + gates, offs, mask=mask)
     tl.store(GateScal + gates, gate_scal, mask=mask)
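The rewritten _routing_compute_indx above folds the old _count_previous helper into a packed stable sort: each (expert, position) pair becomes a single uint32 key (expert << 16) | position, one tl.sort orders the keys by expert while preserving the original order of equal experts, and the keyed scan then gives each gate its rank within its expert. A rough pure-Python model of that packing-and-ranking idea (illustrative only, not Triton code):

expert = [2, 0, 2, 1, 0, 2]                       # hypothetical expert id per gate (fits in 16 bits)
kv = sorted((e << 16) | i for i, e in enumerate(expert))

rank_within_expert = {}
counts = {}
for pair in kv:
    e, i = pair >> 16, pair & 0xffff
    rank_within_expert[i] = counts.get(e, 0)      # exclusive run length, as in the kernel
    counts[e] = counts.get(e, 0) + 1

print(rank_within_expert)                         # {1: 0, 4: 1, 3: 0, 0: 0, 2: 1, 5: 2}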
@@ -117,15 +98,16 @@ def _routing_clear_bitmatrix(Bitmatrix, stride_bm, shape_bn, cutoff, BLOCK_N: tl


 @triton.jit
-def _routing_memset_indx(Indx0, Indx1, size, sentinel, BLOCK: tl.constexpr):
+def _routing_memset_indx(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size,
+                         BLOCK_N: tl.constexpr):
     pid = tl.program_id(0)
-    buf = tl.program_id(1)
-    offs = pid * BLOCK + tl.arange(0, BLOCK)
-    mask = offs < size
-    if buf == 0:
-        tl.store(Indx0 + offs, sentinel, mask=mask)
-    if buf == 1:
-        tl.store(Indx1 + offs, sentinel, mask=mask)
+
+    if pid == 0:
+        _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
+    else:
+        offs = (pid - 1) * BLOCK + tl.arange(0, BLOCK)
+        mask = offs < size
+        tl.store(Indx + offs, sentinel, mask=mask)


 @dataclass
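With this hunk a single launch covers two jobs: program 0 computes the cumulative expert offsets while every other program clears one block of the index buffer, removing a separate host-side launch from the routing path. The host sizes the grid accordingly (see the next hunk); a tiny sketch of that arithmetic with made-up numbers:

def cdiv(a, b):
    return (a + b - 1) // b

n_gates, MEMSET_BLOCK = 6144, 1024                # hypothetical values, for illustration only
grid = (cdiv(n_gates * 2, MEMSET_BLOCK) + 1,)     # +1: program 0 handles the expert offsets
print(grid)                                       # (13,)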
@@ -204,22 +186,15 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
     # perform compaction to update expt_scal / expt_indx
     hist, partial_hist = sum(bitmatrix, partials_block_size=HIST_BLOCK_M, dim=0)
     # scratchpad
-    expt_offs = torch.empty(n_expts_tot + 1, dtype=torch.int32, device=device)
+    expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
     indx_offs = torch.empty((cdiv(n_tokens, HIST_BLOCK_M), n_expts_tot), dtype=torch.int32, device=device)
+    combined_indx = torch.empty(n_gates * 2, dtype=torch.int32, device=device)
     # output
-    topk_indx = torch.empty(n_gates, dtype=torch.int32, device=device)
-    gate_indx = torch.empty(n_gates, dtype=torch.int32, device=device)
+    topk_indx = combined_indx[:n_gates]
+    gate_indx = combined_indx[n_gates:]
     gate_scal = torch.empty(n_gates, dtype=logits.dtype, device=device)
-    _routing_memset_indx[(cdiv(n_gates, MEMSET_BLOCK), 2)](
-        topk_indx,
-        gate_indx,
-        n_gates,
-        -1,
-        BLOCK=MEMSET_BLOCK,
-    )
-    _routing_compute_expt_offs[(1, )](
-        hist, expt_offs, hist.shape[0], BLOCK_N=512  # tunable parameters
-    )
+    _routing_memset_indx[(cdiv(n_gates * 2, MEMSET_BLOCK) + 1, )](combined_indx, n_gates * 2, -1, MEMSET_BLOCK, hist,
+                                                                  expt_offs, hist.shape[0], BLOCK_N=512)
     _routing_compute_indx_offs[(n_expts_tot, )](
         expt_offs, partial_hist,  # inputs
         indx_offs, partial_hist.shape[0], partial_hist.stride(0),  # outputs
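On the host side, topk_indx and gate_indx become views into a single combined_indx allocation, so one contiguous memset initializes both buffers and the fused kernel above needs only one pointer. A minimal PyTorch sketch of the slicing (illustrative; sizes are arbitrary and no GPU is assumed):

import torch

n_gates = 8
combined_indx = torch.empty(n_gates * 2, dtype=torch.int32)
topk_indx = combined_indx[:n_gates]               # view: shares storage with combined_indx
gate_indx = combined_indx[n_gates:]               # view: shares storage with combined_indx
combined_indx.fill_(-1)                           # one fill initializes both index buffers
assert topk_indx.data_ptr() == combined_indx.data_ptr()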

python/test/unit/language/test_matmul.py

Lines changed: 4 additions & 3 deletions
@@ -399,9 +399,10 @@ def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS
     if not is_cuda():
         return

-    # Pipelining of dot_scaled requires tmem_copy to be used, which in turn
-    # requires the scales to be in the blocked layout in global memory.
-    assert out.asm["ttgir"].count("ttng.tc_gen5_mma") == 1
+    if is_cuda():
+        # Pipelining of dot_scaled requires tmem_copy to be used, which in turn
+        # requires the scales to be in the blocked layout in global memory.
+        assert out.asm["ttgir"].count("ttng.tc_gen5_mma") == 1


 def _knob_promote_lhs_to_tmem(monkeypatch):

test/Conversion/amd/fp_to_fp.mlir

Lines changed: 19 additions & 4 deletions
@@ -1,5 +1,5 @@
-// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck --check-prefix=GFX942 %s
-// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck --check-prefix=GFX950 %s
+// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck --check-prefixes=COMMON,GFX942 %s
+// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck --check-prefixes=COMMON,GFX950 %s

 // CHECK-LABEL: f16_to_f32
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -32,15 +32,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // GFX942-COUNT-8: llvm.fptrunc %{{.+}} : f32 to f16
     // GFX950-COUNT-4: llvm.fptrunc %{{.+}} : vector<2xf32> to vector<2xf16>
     %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
-    // GFX942-COUNT-4: rocdl.cvt.pkrtz
-    // GFX950-COUNT-4: rocdl.cvt.pkrtz
+    // COMMON-COUNT-4: rocdl.cvt.pkrtz
     %1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
     tt.return
   }
 }

 // -----

+// CHECK-LABEL: f32_to_f16_single_value
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @f32_to_f16_single_value(%arg0: tensor<1x128xf32, #blocked>) {
+    // COMMON: llvm.fptrunc %{{.+}} : f32 to f16
+    // COMMON-NOT: llvm.fptrunc
+    %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1x128xf32, #blocked> -> tensor<1x128xf16, #blocked>
+    // COMMON: rocdl.cvt.pkrtz
+    // COMMON-NOT: rocdl.cvt.pkrtz
+    %1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<1x128xf32, #blocked> -> tensor<1x128xf16, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
 // CHECK-LABEL: downcast_to_f8
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {

test/TritonGPU/amd/mfma-double-rate.mlir

Lines changed: 5 additions & 6 deletions
@@ -61,8 +61,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr

 // -----

-// When kWidth is set to 4, generate single rated mfma instructions.
-// In a future PR, such cases will still generate double rated mfma instructions with kWidth = 4.
+// When kWidth is set to 4, still generate double rated mfma instructions.

 // CHECK-LABEL:mfma_16x16x32_f16

@@ -74,7 +73,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
       %q: tensor<128x128xf16, #dotOp0>,
       %k: tensor<128x128xf16, #dotOp1>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
-    // CHECK: rocdl.mfma.f32.16x16x16f16 {{.*}} : (vector<4xf16>, vector<4xf16>
+    // CHECK: rocdl.mfma.f32.16x16x32.f16 {{.*}} : (vector<8xf16>, vector<8xf16>
     %qk = tt.dot %q, %k, %cst : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
     tt.return
   }

@@ -92,7 +91,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
       %q: tensor<128x128xbf16, #dotOp0>,
      %k: tensor<128x128xbf16, #dotOp1>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
-    // CHECK: rocdl.mfma.f32.16x16x16bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>
+    // CHECK: rocdl.mfma.f32.16x16x32.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>
     %qk = tt.dot %q, %k, %cst : tensor<128x128xbf16, #dotOp0> * tensor<128x128xbf16, #dotOp1> -> tensor<128x128xf32, #mma>
     tt.return
   }

@@ -110,7 +109,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
       %q: tensor<128x128xf16, #dotOp0>,
       %k: tensor<128x128xf16, #dotOp1>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
-    // CHECK: rocdl.mfma.f32.32x32x8f16 {{.*}} : (vector<4xf16>, vector<4xf16>
+    // CHECK: rocdl.mfma.f32.32x32x16.f16 {{.*}} : (vector<8xf16>, vector<8xf16>
     %qk = tt.dot %q, %k, %cst : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
     tt.return
   }

@@ -128,7 +127,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
       %q: tensor<128x128xbf16, #dotOp0>,
       %k: tensor<128x128xbf16, #dotOp1>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
-    // CHECK: rocdl.mfma.f32.32x32x8bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>
+    // CHECK: rocdl.mfma.f32.32x32x16.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>
     %qk = tt.dot %q, %k, %cst : tensor<128x128xbf16, #dotOp0> * tensor<128x128xbf16, #dotOp1> -> tensor<128x128xf32, #mma>
     tt.return
   }
