
Commit 2ef33c6

[SWP] When num_stages = 2, do not pipeline indirect loads (#4721)
For indirect loads, we try to assign them to later stages:

```
unsigned stagesBetweenLoads = ceil<unsigned>(numStages - 2, maxIndirectionLevel + 1);
int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads;
schedule.insert(loadOp, stage, loadsClusters[indLevel]);
```

If numStages is 2, there is no later stage to assign the indirect loads to. The fix is to not pipeline the indirect loads. We also generalize this to not pipeline an indirect load whenever its indirection level is >= numStages - 1.
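The arithmetic is the crux: with `numStages = 2`, `stagesBetweenLoads` becomes `ceil(0, maxIndirectionLevel + 1) = 0`, so every load collapses onto stage 0 and no later stage separates dependent loads. Below is a minimal standalone sketch of that arithmetic, not the pipeliner itself; `ceilDiv` stands in for the `ceil<unsigned>` helper and the values are illustrative.

```cpp
// Sketch of the stage-assignment arithmetic quoted above (illustrative only).
#include <cstdio>
#include <initializer_list>

// Round-up integer division, standing in for ceil<unsigned>.
static unsigned ceilDiv(unsigned a, unsigned b) { return (a + b - 1) / b; }

int main() {
  const unsigned maxIndirectionLevel = 1; // one load feeding another load's address
  for (unsigned numStages : {2u, 3u}) {
    unsigned stagesBetweenLoads =
        ceilDiv(numStages - 2, maxIndirectionLevel + 1);
    for (unsigned indLevel = 0; indLevel <= maxIndirectionLevel; ++indLevel) {
      unsigned stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads;
      std::printf("numStages=%u indLevel=%u -> stage=%u\n", numStages,
                  indLevel, stage);
    }
  }
  // With numStages = 2, stagesBetweenLoads is 0 and every load lands in
  // stage 0, so there is no later stage left between dependent loads.
  return 0;
}
```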
1 parent c210764 commit 2ef33c6

File tree

2 files changed: +102 -0 lines


lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -606,6 +606,18 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
   if (loadOpToIndLevelAndUse.empty())
     return {};
 
+  for (auto iter = loadOpToIndLevelAndUse.begin();
+       iter != loadOpToIndLevelAndUse.end();) {
+    auto iterNext = iter + 1;
+    if (std::get<1>(*iter) >= numStages - 1)
+      // We assume loads with different dist are assigned to different stages.
+      // If numStages is 2, we will have no stage available for indirect loads
+      // with dist >= 1. In general, when dist is equal to numStages - 1, we
+      // should not pipeline it.
+      loadOpToIndLevelAndUse.erase(iter);
+    iter = iterNext;
+  }
+
   // Check which loads are good for pipelining, and assign them
   // memory layouts.
   llvm::MapVector<Operation *, LoadInfo> loadToInfo =
```
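For illustration only, here is the same filtering idea as self-contained C++: a plain `std::vector` of `(load name, indirection level)` pairs stands in for the pass's `loadOpToIndLevelAndUse` container, and the names are hypothetical.

```cpp
// Sketch, not the pass: drop loads whose indirection level leaves no
// pipeline stage for the loads that depend on them.
#include <algorithm>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
  const int numStages = 2;
  // (load name, indirection level) -- hypothetical stand-in data.
  std::vector<std::pair<std::string, int>> loads = {
      {"directLoad", 0}, {"indirectLoad", 1}};

  // Remove any load with indirection level >= numStages - 1; with
  // numStages = 2 this drops every indirect load (level >= 1).
  loads.erase(std::remove_if(loads.begin(), loads.end(),
                             [&](const auto &entry) {
                               return entry.second >= numStages - 1;
                             }),
              loads.end());

  for (const auto &entry : loads)
    std::printf("still pipelined: %s (level %d)\n", entry.first.c_str(),
                entry.second);
  return 0;
}
```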
New test file

Lines changed: 90 additions & 0 deletions
```mlir
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=2 | FileCheck %s
// CHECK-LABEL: @indirect_load_two_stages
// CHECK: scf.for
// CHECK: tt.dot
// CHECK: tt.load
// CHECK: async_copy_global_to_local
// CHECK: async_copy_global_to_local

#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @indirect_load_two_stages(%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32) attributes {noinline = false} {
    %c32_i32 = arith.constant 32 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked>

    %0 = tt.get_program_id y : i32
    %1 = tt.addptr %arg3, %0 : !tt.ptr<i64>, i32
    %2 = tt.load %1 : !tt.ptr<i64>

    %7 = tt.get_program_id x : i32
    %8 = arith.muli %7, %c16_i32 : i32
    %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
    %15 = tt.splat %8 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
    %18 = arith.addi %15, %10 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>

    %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
    %34 = arith.extsi %arg12 : i32 to i64
    %35 = arith.muli %2, %34 : i64
    %36 = tt.addptr %arg2, %35 : !tt.ptr<f32>, i64

    %47 = tt.splat %arg4 : !tt.ptr<i64> -> tensor<32x!tt.ptr<i64>, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %48 = tt.addptr %47, %20 : tensor<32x!tt.ptr<i64>, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>

    %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
    %61 = arith.extsi %59 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
    %63 = tt.expand_dims %61 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi64, #blocked3>

    %85 = arith.extsi %22 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> to tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
    %107 = tt.splat %36 : !tt.ptr<f32> -> tensor<32x128x!tt.ptr<f32>, #blocked3>
    %108 = tt.splat %34 : i64 -> tensor<32x1xi64, #blocked3>
    %109 = tt.broadcast %63 : tensor<1x128xi64, #blocked3> -> tensor<32x128xi64, #blocked3>

    %101 = tt.splat %arg5 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>, #blocked1>
    %111:1 = scf.for %arg28 = %arg18 to %arg19 step %c32_i32 iter_args(%arg29 = %cst) -> (tensor<16x128xf32, #blocked>) : i32 {
      %129 = tt.splat %arg28 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %160 = tt.addptr %48, %129 : tensor<32x!tt.ptr<i64>, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %161 = tt.load %160 : tensor<32x!tt.ptr<i64>, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %162 = tt.expand_dims %161 {axis = 0 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi64, #blocked1>
      %163 = tt.broadcast %162 : tensor<1x32xi64, #blocked1> -> tensor<16x32xi64, #blocked1>
      %182 = tt.addptr %101, %163 : tensor<16x32x!tt.ptr<f32>, #blocked1>, tensor<16x32xi64, #blocked1>
      %183 = tt.load %182 : tensor<16x32x!tt.ptr<f32>, #blocked1>

      %197 = arith.extsi %arg28 : i32 to i64
      %198 = tt.splat %197 : i64 -> tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
      %199 = arith.addi %198, %85 : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
      %200 = tt.expand_dims %199 {axis = 1 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi64, #blocked3>
      %201 = arith.muli %200, %108 : tensor<32x1xi64, #blocked3>
      %202 = tt.broadcast %201 : tensor<32x1xi64, #blocked3> -> tensor<32x128xi64, #blocked3>
      %203 = arith.addi %202, %109 : tensor<32x128xi64, #blocked3>
      %204 = tt.addptr %107, %203 : tensor<32x128x!tt.ptr<f32>, #blocked3>, tensor<32x128xi64, #blocked3>
      %209 = tt.load %204 : tensor<32x128x!tt.ptr<f32>, #blocked3>

      %210 = triton_gpu.convert_layout %183 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
      %211 = triton_gpu.convert_layout %209 : tensor<32x128xf32, #blocked3> -> tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
      %212 = tt.dot %210, %211, %arg29 : tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x128xf32, #blocked>
      scf.yield %212 : tensor<16x128xf32, #blocked>
    }
    %112 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<16x1xi32, #blocked3>
    %113 = tt.splat %2 : i64 -> tensor<16x1xi64, #blocked3>
    %114 = arith.extsi %112 : tensor<16x1xi32, #blocked3> to tensor<16x1xi64, #blocked3>
    %115 = arith.addi %113, %114 : tensor<16x1xi64, #blocked3>
    %116 = arith.extsi %arg17 : i32 to i64
    %117 = tt.splat %116 : i64 -> tensor<16x1xi64, #blocked3>
    %118 = arith.muli %115, %117 : tensor<16x1xi64, #blocked3>
    %119 = tt.expand_dims %59 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3>
    %120 = tt.broadcast %118 : tensor<16x1xi64, #blocked3> -> tensor<16x128xi64, #blocked3>
    %121 = arith.extsi %119 : tensor<1x128xi32, #blocked3> to tensor<1x128xi64, #blocked3>
    %122 = tt.broadcast %121 : tensor<1x128xi64, #blocked3> -> tensor<16x128xi64, #blocked3>
    %123 = arith.addi %120, %122 : tensor<16x128xi64, #blocked3>
    %124 = tt.splat %arg7 : !tt.ptr<f32> -> tensor<16x128x!tt.ptr<f32>, #blocked3>
    %125 = tt.addptr %124, %123 : tensor<16x128x!tt.ptr<f32>, #blocked3>, tensor<16x128xi64, #blocked3>
    %128 = triton_gpu.convert_layout %111#0 : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #blocked3>
    tt.store %125, %128 : tensor<16x128x!tt.ptr<f32>, #blocked3>
    tt.return
  }
}
```
