Skip to content

Commit fce3e6d

Browse files
authored
[TritonGPU] Fix incorrect mask operand used in for loop pipeliner (#5161)
When the OOB values for a `tt.load` are non-zero, the for loop pipeliner needs to generate an `arith.select` to mask the loaded values with the default OOB value. However, if the load memory requires a layout change, the wrong mask operand was being passed to the `arith.select`, causing a shape mismatch. The fix is to just use the same mask operand of the origianl `tt.load` op. Fixes triton-lang/triton#4739
1 parent 9883a9b commit fce3e6d

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ static Operation *getFirstUseOfPipelinedLoad(Operation *loadOp) {
102102
return firstUser;
103103
}
104104

105-
static int createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
105+
static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
106106
Value insertIdx, Value extractIdx,
107107
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
108108
int numStages, int maxClusterId) {
@@ -192,8 +192,10 @@ static int createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
192192
Value other = loadOp.getOther();
193193
if (other && !isZeroConst(other)) {
194194
auto select = builder.createWithStage<arith::SelectOp>(
195-
loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(), mask,
196-
sharedLoad.getResult(), other);
195+
loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(),
196+
// Use the mask operand from the original load, not the one with a
197+
// potentially transformed layout.
198+
loadOp.getMask(), sharedLoad.getResult(), other);
197199
result = select->getResults();
198200
}
199201

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: triton-opt %s -tritongpu-pipeline | FileCheck %s
2+
3+
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
4+
5+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
6+
7+
// CHECK-LABEL: @softmax_kernel
8+
tt.func public @softmax_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
9+
%cst = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked>
10+
%0 = tt.get_program_id x : i32
11+
%1 = tt.get_num_programs x : i32
12+
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked>
13+
%3 = tt.splat %arg5 : i32 -> tensor<128xi32, #blocked>
14+
// CHECK: [[MASK:%.*]] = arith.cmpi slt, {{.*}} tensor<128xi32,
15+
%4 = arith.cmpi slt, %2, %3 : tensor<128xi32, #blocked>
16+
// CHECK: scf.for
17+
scf.for %arg6 = %0 to %arg4 step %1 : i32 {
18+
%5 = tt.splat %arg1 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked>
19+
%6 = tt.addptr %5, %2 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr<f32>, #blocked>, tensor<128xi32, #blocked>
20+
// CHECK: [[RESULT:%.*]] = triton_gpu.local_load
21+
// CHECK-NEXT: arith.select [[MASK]], [[RESULT]], %cst
22+
%7 = tt.load %6, %4, %cst {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr<f32>, #blocked>
23+
%8 = tt.splat %arg0 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked>
24+
%9 = tt.addptr %8, %2 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr<f32>, #blocked>, tensor<128xi32, #blocked>
25+
tt.store %9, %7, %4 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr<f32>, #blocked>
26+
} {tt.num_stages = 2 : i32}
27+
tt.return
28+
}
29+
30+
}

0 commit comments

Comments
 (0)