Skip to content

Commit 1143c03

Browse files
authored
[BACKEND] Fix layout picked during TMA store pipelining (#6978)
We were picking a shared-memory layout inconsistent with the TMA descriptor. Instead, the layout should be decided during pipelining, reusing the encoding already chosen for the descriptor.
1 parent c79e5d6 commit 1143c03

File tree

4 files changed

+18
-18
lines changed

4 files changed

+18
-18
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,11 @@ def NVMMASharedEncodingAttr :
437437
} else {
438438
swizzlingByteWidth = 0;
439439
}
440-
if (shapePerCTA.size() < 2 || shapePerCTA[order[1]] < 8) {
440+
int flattenOutterDim = 1;
441+
for (int i = 1; i < shapePerCTA.size(); i++) {
442+
flattenOutterDim *= shapePerCTA[order[i]];
443+
}
444+
if (shapePerCTA.size() < 2 || flattenOutterDim < 8) {
441445
swizzlingByteWidth = 0;
442446
}
443447
bool transposed = order[0] == 0;

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,10 @@ getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) {
220220
auto CTALayout = getCTALayout(argType.getEncoding());
221221
// No swizzling for scale for now
222222
auto newLayout = NVMMASharedEncodingAttr::get(
223-
argType.getContext(), argType.getShape(), newOrder, CTALayout,
224-
argType.getElementType(), false);
223+
argType.getContext(), /*swizzlingByteWidth=*/0,
224+
/*transposed=*/false,
225+
/*elementBitWidth=*/argType.getElementType().getIntOrFloatBitWidth(),
226+
/*fp4Padded=*/false, CTALayout);
225227
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
226228
newLayout, SharedMemorySpace);
227229
rewriter.setInsertionPointAfterValue(arg);

lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,8 @@ static SmallVector<TMAStore> getTMAStores(scf::ForOp forOp) {
3333
static Value createAlloc(scf::ForOp &forOp, const TMAStore &store) {
3434
OpBuilder builder(forOp);
3535
RankedTensorType ty = store.src.getType();
36-
// Is this one correct or should it always be [2, 1, 0]?
37-
auto order = triton::gpu::getOrderForMemory(ty);
38-
auto ctaLayout = ttg::getCTALayout(ty.getEncoding());
39-
Attribute encoding = ttg::SwizzledSharedEncodingAttr::get(
40-
ty.getContext(), 1, 1, 1, order, ctaLayout);
41-
if (ty.getRank() > 1) {
42-
encoding = ttg::NVMMASharedEncodingAttr::get(
43-
ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType(),
44-
/*fp4Padded*/ false);
45-
}
36+
auto encoding =
37+
triton::nvidia_gpu::getEncodingFromDescriptor(store.op, ty, store.desc);
4638
Attribute sharedMemorySpace =
4739
triton::gpu::SharedMemorySpaceAttr::get(ty.getContext());
4840
Type memdescType =
@@ -58,7 +50,6 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store,
5850
OpBuilder builder(store.op);
5951
Location loc = store.op->getLoc();
6052
RankedTensorType ty = store.src.getType();
61-
auto ctaLayout = ttg::getCTALayout(ty.getEncoding());
6253

6354
// Put wait before the local_store make the store truly async. We know
6455
// that we are the only user of the CopyLocalToGlobal.

test/TritonGPU/loop-pipeline-hopper.mlir

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -448,20 +448,23 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
448448

449449
// -----
450450
// Test pipelining of descriptor_store
451-
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
452-
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
451+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
452+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
453453
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
454+
// CHECK: #[[$SHARED:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
454455
// CHECK-LABEL: tma_store_pipeline
455-
tt.func public @tma_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc<tensor<1xf32, #shared>>, %arg2: i32, %arg3: i32) attributes {noinline = false} {
456+
tt.func public @tma_store_pipeline(%arg0: tensor<128x128xf32, #blocked>, %arg1: !tt.tensordesc<tensor<128x128xf32, #shared>>, %arg2: i32, %arg3: i32) attributes {noinline = false} {
456457
%c0_i32 = arith.constant 0 : i32
458+
// CHECK: ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #[[$SHARED]], #smem, mutable>
459+
// CHECK: scf.for
457460
scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 {
458461
%1 = arith.divsi %arg4, %arg2 : i32
459462
// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32}
460463
// CHECK-NEXT: ttg.local_store
461464
// CHECK-NEXT: ttng.fence_async_shared
462465
// CHECK-NEXT: ttng.tensor_desc_to_tma_ptr
463466
// CHECK-NEXT: ttng.async_tma_copy_local_to_global
464-
tt.descriptor_store %arg1[%1], %arg0 : !tt.tensordesc<tensor<1xf32, #shared>>, tensor<1xf32, #blocked>
467+
tt.descriptor_store %arg1[%1, %1], %arg0 : !tt.tensordesc<tensor<128x128xf32, #shared>>, tensor<128x128xf32, #blocked>
465468
}
466469
tt.return
467470
}

0 commit comments

Comments (0)