
Commit 318ff2c

[AMD] Enable General Swizzling ConvertLayoutOp (triton-lang#7482)
- Enables ConvertLayoutOp general swizzling for AMD
- Introduces a simple heuristic in OptimizeLDSUsage so that padding-based swizzling is used only when it is required to reduce LDS consumption (see the sketch below)
- Adds an AMD-specific ttg->llvm conversion pattern to process forced padding separately from general swizzling

Co-authored-by: Alexander Efimov <[email protected]>
1 parent 1793a04 commit 318ff2c
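
The heuristic mentioned above lands in the OptimizeLDSUsage diff at the bottom of this commit. As a rough standalone model of that decision (all names and byte counts here are illustrative stand-ins, not the real Triton interfaces): switch the conversion to the padded path, and stop early if padding alone already fits the LDS budget.

// Hedged sketch of the padding-vs-swizzling heuristic; scratchBytes is a
// hypothetical stand-in for triton::AMD::getConvertLayoutScratchInBytes.
#include <cstdio>

int scratchBytes(bool usePadding) {
  // Illustrative sizes only; the real values depend on src/dst layouts.
  return usePadding ? 48 * 1024 : 64 * 1024;
}

// True when the padded conversion alone fits the target LDS size, in which
// case the pass returns without further layout decomposition.
bool paddingAloneFits(int targetLDSBytes) {
  return scratchBytes(/*usePadding=*/true) <= targetLDSBytes;
}

int main() {
  const int kTargetLDS = 64 * 1024; // e.g. 64 KiB of LDS on gfx942
  std::printf("padding alone fits: %s\n",
              paddingAloneFits(kTargetLDS) ? "yes" : "no");
}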

4 files changed: +89 −11 lines changed

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions

@@ -278,9 +278,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

     // Try to use swizzling to implement the conversion
-    // HACK Remove once AMD tests pass for the swizzling path
-    if (targetInfo.isCuda() && succeeded(transferWithinBlockSwizzling(
-                                   op, adaptor.getSrc(), rewriter))) {
+    if (succeeded(
+            transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter))) {
       return success();
     }

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --canonicalize | FileCheck %s
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK: llvm.mlir.global external @global_smem
+  tt.func @convert_layout_general_swizzling(%arg0: tensor<64x64xf32, #blocked0>, %arg1: tensor<64x64x!tt.ptr<f32>, #blocked1>) {
+
+    // verify that the following convert_layout uses general swizzling
+
+    // CHECK-NOT: llvm.lshr
+
+    %0 = ttg.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked1>
+    tt.store %arg1, %0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: convert_layout_padding_swizzling
+  tt.func @convert_layout_padding_swizzling(%arg0: tensor<64x64xf32, #blocked0>, %arg1: tensor<64x64x!tt.ptr<f32>, #blocked1>) {
+
+    // verify that the following convert_layout uses the padded path
+    // (see the getVecAddr lambda in the transferWithinBlockImpl function)
+
+    // CHECK-DAG: [[CST_0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-DAG: [[CST_5:%.*]] = llvm.mlir.constant(5 : i32) : i32
+    // CHECK-DAG: [[OFFSET_0:%.*]] = llvm.lshr {{.*}}, [[CST_5]] : i32
+    // CHECK: [[OFFSET_1:%.*]] = llvm.shl [[OFFSET_0]], [[CST_0]] : i32
+    // CHECK: [[OFFSET_2:%.*]] = llvm.add [[OFFSET_1]], {{.*}} : i32
+    // CHECK: llvm.getelementptr inbounds {{.*}}{{\[}}[[OFFSET_2]]{{\]}}
+
+    %0 = ttg.convert_layout %arg0 {amdgpu.use_padded_scratch_shmem} : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked1>
+    tt.store %arg1, %0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+    tt.return
+  }
+}
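
The lshr/shl/add sequence matched above appears to implement padded = offset + ((offset >> 5) << 0), i.e. one extra padding element per 32-element interval; below is a minimal standalone model of that arithmetic, with the constants taken from this test only.

// Models the padded-offset arithmetic matched by the CHECK lines above.
// The shift amounts (5 and 0) come from this specific test case; other
// shapes and element types would produce different constants.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

uint32_t paddedOffset(uint32_t offset) {
  return offset + ((offset >> 5) << 0);
}

int main() {
  for (uint32_t off : {0u, 31u, 32u, 96u})
    std::printf("offset %3u -> padded %3u\n", off, paddedOffset(off));
}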

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 34 additions & 0 deletions

@@ -1,11 +1,14 @@
+#include "Analysis/AMDGPUAllocation.h"
 #include "PatternTritonGPUOpToLLVM.h"
 #include "Utility.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"

+using ::mlir::transferWithinBlockPadding;
 using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
 using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
+using ::mlir::triton::gpu::ConvertLayoutOp;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::MemDescType;
 using ::triton::gpu::LinearEncodingAttr;
@@ -287,6 +290,36 @@ struct ConvertLayoutOpMFMAToLinearConversion
 protected:
   const TargetInfoBase &targetInfo;
 };
+
+struct ConvertLayoutForcedPadding
+    : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
+
+  explicit ConvertLayoutForcedPadding(LLVMTypeConverter &typeConverter,
+                                      const TargetInfoBase &targetInfo,
+                                      PatternBenefit benefit)
+      : ConvertOpToLLVMPattern<ConvertLayoutOp>(typeConverter, benefit),
+        targetInfo(targetInfo) {}
+
+  LogicalResult
+  matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op->hasAttr(mlir::triton::AMD::AttrSharedMemPadded))
+      return failure();
+    auto srcType = op.getSrc().getType();
+    auto dstType = op.getType();
+    if (!cvtNeedsSharedMemory(srcType, dstType))
+      return failure();
+
+    auto result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
+                                             getTypeConverter(), rewriter);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+protected:
+  const TargetInfoBase &targetInfo;
+};
+
 } // namespace

 void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns(
@@ -296,4 +329,5 @@ void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns(
                                                benefit);
   patterns.add<ConvertLayoutOpMFMAToLinearConversion>(typeConverter, targetInfo,
                                                       benefit);
+  patterns.add<ConvertLayoutForcedPadding>(typeConverter, targetInfo, benefit);
 }
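
ConvertLayoutForcedPadding only fires for ops that carry the AttrSharedMemPadded unit attribute. A hedged fragment (assuming this file's includes and the MLIR APIs used above; forcePaddedScratch is a hypothetical helper name) showing how a client pass would opt a conversion into the padded path, mirroring what OptimizeLDSUsage does in the next diff:

// Fragment, not a complete pass: tag a ConvertLayoutOp so that the
// ConvertLayoutForcedPadding pattern handles it instead of the generic
// swizzling lowering.
void forcePaddedScratch(mlir::triton::gpu::ConvertLayoutOp op) {
  // The pattern checks op->hasAttr(AttrSharedMemPadded) before rewriting.
  op->setAttr(mlir::triton::AMD::AttrSharedMemPadded,
              mlir::UnitAttr::get(op.getContext()));
}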

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp

Lines changed: 12 additions & 8 deletions

@@ -94,15 +94,26 @@ class OptimizeAMDLDSUsage
     LDBG("Trying fit " << cvtOp << " into " << targetLDSSize << " bytes");
     OpBuilder builder(cvtOp);

+    auto ctx = builder.getContext();
     auto srcType = cvtOp.getSrc().getType();
     auto dstType = cvtOp.getType();

+    if (!cvtOp->hasAttr(triton::AMD::AttrSharedMemPadded)) {
+      auto emptyAttribute = UnitAttr::get(ctx);
+      // Padded conversion seems more friendly to this optimization;
+      // use it instead of general swizzling.
+      cvtOp->setAttr(triton::AMD::AttrSharedMemPadded, emptyAttribute);
+      // If the padded layout reduces LDS usage enough by itself, we are done.
+      if (triton::AMD::getConvertLayoutScratchInBytes(
+              srcType, dstType, /*usePadding*/ true) <= targetLDSSize)
+        return;
+    }
+
     auto srcEnc =
         cast<triton::gpu::DistributedEncodingTrait>(srcType.getEncoding());
     auto dstEnc =
         cast<triton::gpu::DistributedEncodingTrait>(dstType.getEncoding());

-    auto ctx = srcEnc.getContext();
     auto rank = srcType.getRank();

     unsigned numWarps = triton::gpu::lookupNumWarps(cvtOp);
@@ -244,13 +255,6 @@ class OptimizeAMDLDSUsage
       LDSLimit = targetInfo.getSharedMemorySize();
     }

-    auto context = mod.getContext();
-    auto emptyAttribute = UnitAttr::get(context);
-    // TODO choose between padded and swizzled memory patterns
-    mod.walk([emptyAttribute](triton::gpu::ConvertLayoutOp op) -> void {
-      op->setAttr(mlir::triton::AMD::AttrSharedMemPadded, emptyAttribute);
-    });
-
     ModuleAllocation allocAnalysis(
         mod, mlir::triton::AMD::AMDAllocationAnalysisScratchSizeFn);
     if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
