Commit 75c2972

[SWP] Remove redundant SMEM encoding creation for MMAv3 (triton-lang#5640)
When we determine the SMEM encoding for a multi-buffered SMEM, we should reuse the encoding of the operand SMEM created by `AccelerateMatmul`. We do have such logic in the code, but currently there is an additional MMAv3-specific code path before it that creates a fresh encoding which, in practice, always coincides with the existing operand encoding. https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp#L337-L361

The exception to this is multi-buffering of a TMA load. `AccelerateMatmul` may create an encoding [whose `order` is a transpose of the register `order`](https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp#L151-L159). We cannot use such an encoding as the destination of TMA. So for TMA loads, SWP always creates a new encoding that is known to be compatible with TMA. (When the TMA and MMA operand encodings differ and the TMA one is not compatible with MMA, e.g. MMAv3 with a row-major fp8 RHS, SWP ends up producing an invalid program because the TMA layout overwrites the operand encoding. We should not pipeline the TMA load in such a case. This is a bug that should be fixed.)

This change is mostly a nit for the current main branch, but it is motivated by a case where we want to create a new kind of SMEM encoding representing a more complicated layout. Ideally, we only want to do that once in `AccelerateMatmul` and reuse it in SWP rather than repeating the same code there.

cc @ThomasRaoux @pawelszczerbuk @csullivan @mbrookhart

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: Masahiro Masuda <[email protected]>
1 parent 7fffa0d commit 75c2972
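
The decision flow described in the commit message boils down to the following simplified sketch, condensed from the `MatmulLoopPipeline.cpp` diff below. It is not verbatim Triton source; the elided branch and the surrounding declarations (`op`, `loadInfo`, `dot`) are assumed from context.

```cpp
// Simplified sketch of how the shared encoding is now chosen per pipelined
// load (condensed from the diff below, not verbatim).
bool isTMALoad = isa<tt::ExperimentalDescriptorLoadOp>(op);

if (loadInfo.isMMAv3Shared || isTMALoad) {
  // TMA destinations (and MMAv3 shared operands) get a freshly built
  // encoding that is known to be a valid TMA destination; it takes
  // precedence over the local_alloc created for the MMA operand.
  loadInfo.sharedEncoding = getSharedEncoding(&op, isTMALoad).value_or(nullptr);
} else if (loadInfo.isMMAv3Registers || dot) {
  // Otherwise, reuse the operand SMEM encoding that AccelerateMatmul
  // already created (unchanged branch, elided here).
}
```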

File tree

2 files changed: +58 -9 lines

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
9 additions & 9 deletions

```diff
@@ -319,7 +319,7 @@ getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) {
 }
 
 static std::optional<ttg::SharedEncodingAttr>
-getSharedEncoding(Operation *loadOp, bool isMMAV3Shared) {
+getSharedEncoding(Operation *loadOp, bool isTMALoad) {
   auto ty = cast<RankedTensorType>(loadOp->getResultTypes()[0]);
   auto ctaLayout = ttg::getCTALayout(ty.getEncoding());
   auto blockedOrder = ttg::getOrder(ty.getEncoding());
@@ -334,7 +334,10 @@ getSharedEncoding(Operation *loadOp, bool isMMAV3Shared) {
   } else {
     order = blockedOrder;
   }
-  if (isMMAV3Shared) {
+
+  if (isTMALoad) {
+    // For TMA, the encoding compatible with it takes precedence over local
+    // alloc created for the MMA operand.
     return ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), order,
                                         ctaLayout, ty.getElementType());
   }
@@ -487,6 +490,7 @@ assignMemoryLayouts(scf::ForOp &forOp,
      }
    });
 
+    bool isTMALoad = isa<tt::ExperimentalDescriptorLoadOp>(op);
    loadsToPipeline.insert(&op);
    LoadInfo loadInfo;
    for (auto use : users) {
@@ -501,12 +505,9 @@ assignMemoryLayouts(scf::ForOp &forOp,
      loadInfo.isMMAv3Registers =
          (mmaLoadType == MMALoadType::Registers) && warpGroupDot;
 
-      if (loadInfo.isMMAv3Shared) {
-        loadInfo.sharedEncoding =
-            getSharedEncoding(&op, /*loadIsMMAv3=*/true).value_or(nullptr);
-      } else if (isa<tt::ExperimentalDescriptorLoadOp>(op)) {
+      if (loadInfo.isMMAv3Shared || isTMALoad) {
        loadInfo.sharedEncoding =
-            getSharedEncoding(&op, /*loadIsMMAv3=*/true).value_or(nullptr);
+            getSharedEncoding(&op, isTMALoad).value_or(nullptr);
      } else if (loadInfo.isMMAv3Registers || dot) {
        bool incompatible = false;
        loadInfo.sharedEncoding =
@@ -520,8 +521,7 @@ assignMemoryLayouts(scf::ForOp &forOp,
      if (!loadInfo.sharedEncoding && !isa<ttng::WarpGroupDotOp>(use)) {
        LDBG("try generic shared encoding");
        loadInfo.sharedEncoding =
-            getSharedEncoding(&op, /*isMMAV3=*/loadInfo.isMMAv3Shared)
-                .value_or(nullptr);
+            getSharedEncoding(&op, isTMALoad).value_or(nullptr);
        if (auto loadOp = dyn_cast<tt::LoadOp>(op))
          loadInfo.blockedEncoding =
              getBlockedEncoding(loadOp, axisInfoAnalysis);
```
test/TritonGPU/loop-pipeline-hopper.mlir
49 additions & 0 deletions

```diff
@@ -1007,3 +1007,52 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
     tt.return %17#0 : tensor<128x16xf32, #mma>
   }
 }
+
+// -----
+
+#shared = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}>
+#shared1 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @mmav3_fp8_row_major_rhs(%arg0: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg1: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg2: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    // CHECK-LABEL: mmav3_fp8_row_major_rhs
+    // The col-major RHS SMEM encoding in the input, created by accelerate-matmul, should be overwritten by the row-major TMA layout.
+    // Note that this "overwriting" makes the program invalid after SWP, since warp_group_dot does not support row-major fp8 RHS.
+    // In this case, the TMA load on B should not be pipelined. When this bug is fixed, this test should be rewritten to verify that.
+    // CHECK-NOT: order = [0, 1]
+    %c128_i32 = arith.constant 128 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c127_i32 = arith.constant 127 : i32
+    %c63_i32 = arith.constant 63 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
+    %0 = tt.get_program_id x : i32
+    %1 = arith.addi %arg3, %c127_i32 : i32
+    %2 = arith.divsi %1, %c128_i32 : i32
+    %3 = arith.remsi %0, %2 : i32
+    %4 = arith.divsi %0, %2 : i32
+    %5 = arith.muli %3, %c128_i32 : i32
+    %6 = arith.muli %4, %c64_i32 : i32
+    %7 = arith.addi %arg5, %c63_i32 : i32
+    %8 = arith.divsi %7, %c64_i32 : i32
+    %9 = tt.reinterpret_tensor_descriptor %arg0 : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128x64xf8E4M3FN>>
+    %10 = tt.reinterpret_tensor_descriptor %arg1 : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<64x64xf8E4M3FN>>
+    %true = arith.constant true
+    %false = arith.constant false
+    %11:2 = scf.for %arg6 = %c0_i32 to %8 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x64xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 32]}>>, i32) : i32 {
+      %14 = tt.experimental_descriptor_load %9[%5, %arg8] : !tt.tensordesc<tensor<128x64xf8E4M3FN>> -> tensor<128x64xf8E4M3FN, #blocked>
+      %15 = ttg.local_alloc %14 : (tensor<128x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared, #ttg.shared_memory>
+      %16 = tt.experimental_descriptor_load %10[%arg8, %6] : !tt.tensordesc<tensor<64x64xf8E4M3FN>> -> tensor<64x64xf8E4M3FN, #blocked>
+      %17 = ttg.local_alloc %16 : (tensor<64x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<64x64xf8E4M3FN, #shared1, #ttg.shared_memory>
+      %18 = ttng.warp_group_dot %15, %17, %arg7 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x64xf8E4M3FN, #shared, #ttg.shared_memory> * !ttg.memdesc<64x64xf8E4M3FN, #shared1, #ttg.shared_memory> -> tensor<128x64xf32, #mma>
+      %19 = arith.addi %arg8, %c64_i32 : i32
+      scf.yield %18, %19 : tensor<128x64xf32, #mma>, i32
+    }
+    %12 = ttg.convert_layout %11#0 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
+    %13 = tt.reinterpret_tensor_descriptor %arg2 : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128x64xf32>>
+    tt.experimental_descriptor_store %13[%5, %6], %12 : !tt.tensordesc<tensor<128x64xf32>>, tensor<128x64xf32, #blocked>
+    tt.return
+  }
+}
```
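
The test comments above note that, until the bug is fixed, a TMA load whose TMA-compatible encoding disagrees with what the MMA operand requires should simply not be pipelined. A minimal sketch of what such a guard could look like, purely illustrative and not part of this commit (the helper `isTMACompatible` and the variable `operandEncoding` are hypothetical names, not Triton APIs):

```cpp
// Hypothetical guard (not in this commit): skip pipelining a TMA load when
// the operand encoding created by AccelerateMatmul cannot also serve as a
// TMA destination, instead of silently overwriting it with the TMA layout.
if (isTMALoad && operandEncoding &&          // operandEncoding: hypothetical
    !isTMACompatible(operandEncoding)) {     // isTMACompatible: hypothetical
  // Leaving the load out of loadsToPipeline keeps the program valid, at the
  // cost of not multi-buffering this particular load.
  continue;
}
```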
