Skip to content

Commit 4107453

Browse files
authored
[PIPELINER] tweak pipeline heuristic (#5247)
Don't pipeline the dot accumulator in the default heuristic. In the future, finer-grained control will allow the user to decide.
1 parent e3ab295 commit 4107453

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) {
5656
distance++;
5757
}
5858
for (Value operand : op->getOperands()) {
59+
if (op->hasTrait<OpTrait::DotLike>()) {
60+
// Heuristic: only pipeline A and B operands of the dot op.
61+
if (operand == op->getOperand(2))
62+
continue;
63+
}
5964
Value v = operand;
6065
Operation *defOp = v.getDefiningOp();
6166
if (defOp && defOp->getBlock() == op->getBlock()) {

test/TritonGPU/loop-schedule.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 | FileCheck %s
2+
3+
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
4+
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
5+
#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
6+
#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}>
7+
#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}>
8+
#CLs0 = #triton_gpu.slice<{parent=#C, dim=0}>
9+
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
10+
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
11+
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} {
12+
// CHECK-LABEL: @matmul_loop_load_acc
13+
// CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
14+
// CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
15+
// CHECK: tt.load %{{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
16+
// CHECK: tt.dot {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
17+
tt.func @matmul_loop_load_acc(%lb : index, %ub : index, %step : index,
18+
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
19+
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32},
20+
%C : !tt.ptr<f32> {tt.divisibility = 16 : i32},
21+
%c_init: tensor<128x128xf32, #C>) -> tensor<128x128xf32, #C> {
22+
23+
// A ptrs
24+
%a_ptr_splat = tt.splat %A : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #AL>
25+
%a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0>
26+
%a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL>
27+
%a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL>
28+
%a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
29+
// B ptrs
30+
%b_ptr_splat = tt.splat %B : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #BL>
31+
%b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0>
32+
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL>
33+
%b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL>
34+
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
35+
// C ptrs
36+
%c_ptr_splat = tt.splat %C : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #C>
37+
%c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #CLs0>
38+
%c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32, #CLs0> -> tensor<1x128xi32, #C>
39+
%c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32, #C> -> tensor<128x128xi32, #C>
40+
%c_ptr_init = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr<f32>, #C>, tensor<128x128xi32, #C>
41+
42+
%a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
43+
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
44+
%c_off = arith.constant dense<4> : tensor<128x128xi32, #C>
45+
46+
%loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %c_ptr = %c_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128x!tt.ptr<f32>, #C>, tensor<128x128xf32, #C>) {
47+
%a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
48+
%a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
49+
%b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
50+
%b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
51+
%c_ = tt.load %c_ptr : tensor<128x128x!tt.ptr<f32>, #C>
52+
%c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
53+
54+
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
55+
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
56+
%next_c_ptr = tt.addptr %c_ptr, %c_off : tensor<128x128x!tt.ptr<f32>, #C>, tensor<128x128xi32, #C>
57+
scf.yield %next_a_ptr, %next_b_ptr, %next_c_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128x!tt.ptr<f32>, #C>, tensor<128x128xf32, #C>
58+
}
59+
tt.return %loop#3: tensor<128x128xf32, #C>
60+
}
61+
}

0 commit comments

Comments
 (0)