Commit a86c5e7

ThomasRaoux authored and makslevental committed

[Pipeliner] Fix mmav3 pipelining (triton-lang#5844)

Make sure we allocate the right number of slices when doing mmav3 pipelining.

1 parent be04dd2 · commit a86c5e7
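As a rough illustration of the slice count this fix is about (a minimal sketch, not the Triton implementation; the helper name and shapes are made up for illustration): the pipeliner multi-buffers each pipelined load with one shared-memory slice per in-flight stage, which shows up as the leading dimension of the shared allocation.

#include <cstdint>
#include <vector>

// Hypothetical helper, for illustration only: the multi-buffered shared-memory
// shape for a pipelined load is its tile shape with the stage count prepended.
std::vector<int64_t> multiBufferedShape(const std::vector<int64_t> &tileShape,
                                        int64_t numStages) {
  std::vector<int64_t> shape;
  shape.reserve(tileShape.size() + 1);
  shape.push_back(numStages); // one slice per in-flight pipeline stage
  shape.insert(shape.end(), tileShape.begin(), tileShape.end());
  return shape;
}

// With tt.num_stages = 4 and a 256x32xf32 operand this yields {4, 256, 32},
// matching the `!ttg.memdesc<4x256x32xf32` CHECK line in the new lit test below.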

4 files changed: 45 additions, 41 deletions

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 0 additions & 2 deletions
@@ -200,8 +200,6 @@ StringRef getAMDArch(Operation *module);
 std::optional<mlir::triton::gpu::SwizzledSharedEncodingAttr>
 getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
 
-bool canUseMMAv3Pipelining(Operation *loadOp);
-
 // Convert \param op operands and results to layout \param encoding.
 void convertOpEncoding(Attribute encoding, Operation *op);
 
lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 11 additions & 7 deletions
@@ -476,6 +476,16 @@ getTransitiveUserInBlock(Operation *baseOp, scf::ForOp &forOp) {
   return users;
 }
 
+static bool isMMAv3Buffer(Operation *loadOp) {
+  if (!loadOp->hasOneUse())
+    return false;
+  Operation *user = *loadOp->getUsers().begin();
+  if (auto alloc = dyn_cast<ttg::LocalAllocOp>(user)) {
+    return isa<ttg::NVMMASharedEncodingAttr>(alloc.getType().getEncoding());
+  }
+  return false;
+}
+
 static llvm::MapVector<Operation *, LoadInfo>
 assignMemoryLayouts(scf::ForOp &forOp,
                     tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
@@ -517,16 +527,10 @@ assignMemoryLayouts(scf::ForOp &forOp,
     loadsToPipeline.insert(&op);
     LoadInfo loadInfo;
     for (auto use : users) {
-      // By default we will try pipelining with load to registers at the end.
-      // For mmav3 we can try leaving the operands in shared memory.
-      bool mmav3Shmem = false;
       if (isa<mlir::triton::DotOpInterface>(use)) {
         LDBG("set shared encoding with dot user: " << *use);
         auto dot = dyn_cast<tt::DotOp>(use);
-        bool isMMAv3v5Dot = isa<ttng::WarpGroupDotOp, ttng::TCGen5MMAOp,
-                                ttng::TCGen5MMAScaledOp>(use);
-        mmav3Shmem = canUseMMAv3Pipelining(&op) && isMMAv3v5Dot;
-
+        bool mmav3Shmem = isMMAv3Buffer(&op);
         loadInfo.usedByDot = true;
         loadInfo.isMMAv3Shared = mmav3Shmem;
 

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 0 additions & 32 deletions
@@ -1045,38 +1045,6 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
   return attr;
 }
 
-bool canUseMMAv3Pipelining(Operation *loadOp) {
-  Operation *user = *loadOp->getUsers().begin();
-  while (isa<triton::TransOp, triton::ReshapeOp>(user)) {
-    if (!user->hasOneUse())
-      return false;
-    user = *user->getUsers().begin();
-  }
-  if (!user)
-    return false;
-
-  if (auto alloc = dyn_cast<ttg::LocalAllocOp>(user)) {
-    auto sharedEnc =
-        dyn_cast<ttg::NVMMASharedEncodingAttr>(alloc.getType().getEncoding());
-
-    if (!sharedEnc)
-      return false;
-
-    // MMA V3 case.
-    SmallVector<unsigned> newOrder = getOrder(sharedEnc);
-    auto ty = cast<RankedTensorType>(loadOp->getResultTypes()[0]);
-    auto oldOrder = ttg::getOrder(ty.getEncoding());
-
-    // The operand of MMAv3 is in SharedEncoding and its order should not
-    // be changed after FuseTranspositions Pass. So we only pipeline the
-    // load if the order of the loaded BlockedEncoding is the same as the
-    // order of the SharedEncoding it is converted to.
-    return oldOrder == newOrder;
-  } else {
-    return false;
-  }
-}
-
 namespace {
 
 /// Detect dead arguments in scf.for op by assuming all the values are dead and
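To summarize the behavioral change (a simplified sketch, not the actual compiler code; the enum and function below are hypothetical): the removed canUseMMAv3Pipelining walked through tt.trans/tt.reshape users and additionally required the load's blocked-layout order to match the shared encoding's order, whereas the new isMMAv3Buffer only accepts a load whose single, direct user is a ttg.local_alloc with an NVMMA shared encoding.

#include <vector>

// Hypothetical model of a load's users, for illustration only.
enum class UserKind { LocalAllocNVMMA, LocalAllocOther, Trans, Reshape, Other };

// Sketch of the new rule from isMMAv3Buffer: exactly one user, and that user
// is directly a local_alloc carrying an NVMMA shared encoding. There is no
// looking through trans/reshape and no layout-order comparison anymore.
bool isMMAv3BufferSketch(const std::vector<UserKind> &users) {
  if (users.size() != 1)
    return false;
  return users.front() == UserKind::LocalAllocNVMMA;
}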
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-pipeline | FileCheck %s --check-prefixes=CHECK
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 8]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @pipeline_load_mmav3
+  tt.func public @pipeline_load_mmav3(%arg0: tensor<256x128xf32, #mma>, %arg1: tensor<256x32x!tt.ptr<f32>, #blocked>, %arg2: tensor<32x128x!tt.ptr<f32>, #blocked1>, %arg3: tensor<256x32xi32, #blocked>, %arg4: tensor<32x128xi32, #blocked1>) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f32>, #blocked>, tensor<32x128x!tt.ptr<f32>, #blocked1>) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c128_i32 = arith.constant 128 : i32
+    // CHECK: ttg.local_alloc : () -> !ttg.memdesc<4x256x32xf32
+    // CHECK: ttg.local_alloc : () -> !ttg.memdesc<4x32x128xf32
+    %0:3 = scf.for %arg5 = %c0_i32 to %c128_i32 step %c1_i32 iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f32>, #blocked>, tensor<32x128x!tt.ptr<f32>, #blocked1>) : i32 {
+      // CHECK: ttg.memdesc_subview {{.*}} : !ttg.memdesc<4x256x32xf32
+      // CHECK: ttg.async_wait {{.*}} {num = 4 : i32}
+      // CHECK: ttg.memdesc_subview {{.*}} : !ttg.memdesc<4x32x128xf32
+      // CHECK: ttng.warp_group_dot {{.*}} {inputPrecision = 0 : i32, isAsync = true}
+      // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32}
+      %1 = tt.load %arg7 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<256x32x!tt.ptr<f32>, #blocked>
+      %2 = ttg.local_alloc %1 {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<256x32xf32, #blocked>) -> !ttg.memdesc<256x32xf32, #shared, #smem>
+      %3 = tt.load %arg8 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<32x128x!tt.ptr<f32>, #blocked1>
+      %4 = ttg.local_alloc %3 {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<32x128xf32, #blocked1>) -> !ttg.memdesc<32x128xf32, #shared1, #smem>
+      %5 = ttng.warp_group_dot %2, %4, %arg6 {inputPrecision = 0 : i32, loop.cluster = 0 : i32, loop.stage = 3 : i32} : !ttg.memdesc<256x32xf32, #shared, #smem> * !ttg.memdesc<32x128xf32, #shared1, #smem> -> tensor<256x128xf32, #mma>
+      %6 = tt.addptr %arg7, %arg3 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<256x32x!tt.ptr<f32>, #blocked>, tensor<256x32xi32, #blocked>
+      %7 = tt.addptr %arg8, %arg4 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<32x128x!tt.ptr<f32>, #blocked1>, tensor<32x128xi32, #blocked1>
+      scf.yield %5, %6, %7 : tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f32>, #blocked>, tensor<32x128x!tt.ptr<f32>, #blocked1>
+    } {tt.num_stages = 4 : i32}
+    tt.return %0#0, %0#1, %0#2 : tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f32>, #blocked>, tensor<32x128x!tt.ptr<f32>, #blocked1>
+  }
+}
