Commit 13f558a
Reland "[AMD] Enable General Swizzling ConvertLayoutOp (#7482)" (#4886)
Fixes #4795
2 parents fad8c9b + a6993ad commit 13f558a

6 files changed: +102 −10 lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ class TargetInfoBase {
   virtual bool supportLdMatrix() const { return false; }
   virtual bool supportStMatrix() const { return false; }
   virtual bool isCuda() const { return false; }
+  virtual bool isXpu() const { return false; }
 
   // Annotate target specific information to local load operations during
   // lowering to LLVM. `llLoadOp` is the generated LLVM load op.

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -279,8 +279,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
 
     // Try to use swizzling to implement the conversion
-    // HACK Remove once AMD tests pass for the swizzling path
-    if (targetInfo.isCuda() && succeeded(transferWithinBlockSwizzling(
+    // HACK Remove once XPU tests pass for the swizzling path
+    if (!targetInfo.isXpu() && succeeded(transferWithinBlockSwizzling(
             op, adaptor.getSrc(), rewriter))) {
       return success();
     }
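The gate above is the heart of the reland: instead of whitelisting CUDA, the general swizzling path is now taken by every backend except the Intel XPU target, which reports itself through the new isXpu() hook. A standalone sketch of that dispatch, using simplified stand-in classes rather than the real MLIR types:

// Sketch only: models the isCuda() -> !isXpu() change in this hunk with
// simplified stand-in types; the real hooks live in TargetInfoBase.h.
#include <iostream>

struct TargetInfoBase {
  virtual ~TargetInfoBase() = default;
  virtual bool isCuda() const { return false; }
  virtual bool isXpu() const { return false; } // new hook added by this commit
};

struct CudaTargetInfo : TargetInfoBase {
  bool isCuda() const override { return true; }
};
struct AMDTargetInfo : TargetInfoBase {}; // neither CUDA nor XPU
struct XpuTargetInfo : TargetInfoBase {
  bool isXpu() const override { return true; } // mirrors the Intel override below
};

// Old gate: only CUDA tried the general swizzling path.
bool oldGate(const TargetInfoBase &ti) { return ti.isCuda(); }
// New gate: everything except the XPU target tries it.
bool newGate(const TargetInfoBase &ti) { return !ti.isXpu(); }

int main() {
  AMDTargetInfo amd;
  std::cout << oldGate(amd) << " -> " << newGate(amd) << "\n"; // 0 -> 1: AMD now swizzles
  XpuTargetInfo xpu;
  std::cout << oldGate(xpu) << " -> " << newGate(xpu) << "\n"; // 0 -> 0: XPU keeps the old path
}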
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --cse| FileCheck %s
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+// CHECK: llvm.mlir.global external @global_smem
+  tt.func @convert_layout_general_swizzling(%arg0: tensor<64x64xf32, #blocked0>, %arg1: tensor<64x64x!tt.ptr<f32>, #blocked1>) {
+
+    // verify that following convert layout uses general swizzling path
+
+    // CHECK: [[CST_128:%.*]] = llvm.mlir.constant(128 : i32) : i32
+
+    // Part of offset computation generated by applyLinearLayout function
+    // CHECK: [[SEL:%.*]]= llvm.select {{.*}}, {{.*}}, [[CST_128]]
+    // CHECK: [[OFFSET_0:%.*]] = llvm.xor {{.*}}, [[SEL]]
+    // CHECK: [[OFFSET_1:%.*]] = llvm.xor {{.*}}, [[OFFSET_0]] : i32
+
+    // Part of offset computation generated by lowerLdSt function after applyLinearLayout
+    // CHECK: [[OFFSET_2:%.*]] = llvm.xor [[OFFSET_1]], {{.*}} : i32
+    // CHECK: [[OFFSET_3:%.*]] = llvm.xor [[OFFSET_2]], {{.*}} : i32
+    // CHECK: [[OFFSET_4:%.*]] = llvm.add [[OFFSET_3]], {{.*}} : i32
+    // CHECK: llvm.getelementptr inbounds {{.*}}{{\[}}[[OFFSET_4]]{{\]}}
+
+    %0 = ttg.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked1>
+    tt.store %arg1, %0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: convert_layout_padding_swizzling
+  tt.func @convert_layout_padding_swizzling(%arg0: tensor<64x64xf32, #blocked0>, %arg1: tensor<64x64x!tt.ptr<f32>, #blocked1>) {
+
+    // verify that following convert layout uses padded path
+    // see getVecAddr lambda in transferWithinBlockImpl function
+
+    // CHECK-DAG: [[CST_0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-DAG: [[CST_5:%.*]] = llvm.mlir.constant(5 : i32) : i32
+    // CHECK-DAG: [[OFFSET_0:%.*]] = llvm.lshr {{.*}}, [[CST_5]] : i32
+    // CHECK: [[OFFSET_1:%.*]] = llvm.shl [[OFFSET_0]], [[CST_0]] : i32
+    // CHECK: [[OFFSET_2:%.*]] = llvm.add [[OFFSET_1]], {{.*}} : i32
+    // CHECK: llvm.getelementptr inbounds {{.*}}{{\[}}[[OFFSET_2]]{{\]}}
+
+    %0 = ttg.convert_layout %arg0 {amdgpu.use_padded_scratch_shmem} : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked1>
+    tt.store %arg1, %0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 33 additions & 0 deletions
@@ -1,3 +1,4 @@
+#include "Analysis/AMDGPUAllocation.h"
 #include "PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
@@ -110,6 +111,35 @@ class ConvertLayoutOpMFMAToLinearConversion
     return success();
   }
 
+private:
+  const TargetInfoBase &targetInfo;
+};
+
+class ConvertLayoutForcedPadding
+    : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
+public:
+  ConvertLayoutForcedPadding(LLVMTypeConverter &typeConverter,
+                             const TargetInfoBase &targetInfo,
+                             PatternBenefit benefit)
+      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
+  }
+
+  LogicalResult
+  matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op->hasAttr(mlir::triton::AMD::AttrSharedMemPadded))
+      return failure();
+    auto srcType = op.getSrc().getType();
+    auto dstType = op.getType();
+    if (!cvtNeedsSharedMemory(srcType, dstType))
+      return failure();
+
+    auto result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
+                                             getTypeConverter(), rewriter);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
 private:
   const TargetInfoBase &targetInfo;
 };
@@ -120,4 +150,7 @@ void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ConvertLayoutOpMFMAToLinearConversion>(typeConverter, targetInfo,
                                                       benefit);
+  patterns.add<ConvertLayoutForcedPadding>(typeConverter, targetInfo, benefit);
+  // No need to convert when ForcedSwizzling as it's already the default
+  // lowering
 }
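ConvertLayoutForcedPadding is deliberately narrow: it claims a convert_layout only when the op already carries the AttrSharedMemPadded marker and the conversion actually needs shared memory; everything else fails the match and falls through to the default (swizzling) lowering. A minimal standalone model of that guard, with plain structs standing in for the MLIR op machinery; the attribute spelling is taken from the new lit test and is assumed to be what AttrSharedMemPadded resolves to:

// Sketch only: models ConvertLayoutForcedPadding's matchAndRewrite guard.
// FakeCvtOp stands in for triton::gpu::ConvertLayoutOp; the attribute name is
// assumed from the lit test ({amdgpu.use_padded_scratch_shmem}).
#include <iostream>
#include <set>
#include <string>

struct FakeCvtOp {
  std::set<std::string> attrs;
  bool needsSharedMemory; // stand-in for cvtNeedsSharedMemory(src, dst)
  bool hasAttr(const std::string &name) const { return attrs.count(name) != 0; }
};

// Returns true when the forced-padding pattern would rewrite the op;
// false means the pattern bails out and the default swizzling lowering runs.
bool forcedPaddingApplies(const FakeCvtOp &op) {
  if (!op.hasAttr("amdgpu.use_padded_scratch_shmem"))
    return false;
  return op.needsSharedMemory;
}

int main() {
  FakeCvtOp tagged{{"amdgpu.use_padded_scratch_shmem"}, true};
  FakeCvtOp untagged{{}, true};
  std::cout << forcedPaddingApplies(tagged) << "\n";   // 1: padded lowering
  std::cout << forcedPaddingApplies(untagged) << "\n"; // 0: default swizzling
}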

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp

Lines changed: 12 additions & 8 deletions
@@ -94,15 +94,26 @@ class OptimizeAMDLDSUsage
     LDBG("Trying fit " << cvtOp << " into " << targetLDSSize << " bytes");
     OpBuilder builder(cvtOp);
 
+    auto ctx = builder.getContext();
     auto srcType = cvtOp.getSrc().getType();
     auto dstType = cvtOp.getType();
 
+    if (!cvtOp->hasAttr(triton::AMD::AttrSharedMemPadded)) {
+      auto emptyAttribute = UnitAttr::get(ctx);
+      // Padded conversion seems more friendly with this optimization
+      // use it instead of general swizzling.
+      cvtOp->setAttr(triton::AMD::AttrSharedMemPadded, emptyAttribute);
+      // if padded layout drops LDS usage on itself, we are done, return
+      if (triton::AMD::getConvertLayoutScratchInBytes(
+              srcType, dstType, /*usePadding*/ true) <= targetLDSSize)
+        return;
+    }
+
     auto srcEnc =
         cast<triton::gpu::DistributedEncodingTrait>(srcType.getEncoding());
     auto dstEnc =
         cast<triton::gpu::DistributedEncodingTrait>(dstType.getEncoding());
 
-    auto ctx = srcEnc.getContext();
     auto rank = srcType.getRank();
 
     unsigned numWarps = triton::gpu::lookupNumWarps(cvtOp);
@@ -244,13 +255,6 @@ class OptimizeAMDLDSUsage
       LDSLimit = targetInfo.getSharedMemorySize();
     }
 
-    auto context = mod.getContext();
-    auto emptyAttribute = UnitAttr::get(context);
-    // TODO choose between padded and swizzled memory patterns
-    mod.walk([emptyAttribute](triton::gpu::ConvertLayoutOp op) -> void {
-      op->setAttr(mlir::triton::AMD::AttrSharedMemPadded, emptyAttribute);
-    });
-
     ModuleAllocation allocAnalysis(
         mod, mlir::triton::AMD::AMDAllocationAnalysisScratchSizeFn);
     if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
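The net effect of this diff is that the pass no longer blanket-tags every convert_layout as padded up front; it only switches an oversized conversion to the padded path while trying to fit it into LDS, and stops early when that alone brings the scratch requirement under the limit. A rough standalone model of that early exit; the byte figures are invented, the real numbers come from triton::AMD::getConvertLayoutScratchInBytes:

// Sketch only: models the new early-exit in the LDS fitting logic.
// scratchBytes() is a stand-in for getConvertLayoutScratchInBytes and
// returns hypothetical values purely for illustration.
#include <cstdint>
#include <iostream>

int64_t scratchBytes(bool usePadding) {
  return usePadding ? 48 * 1024 : 96 * 1024; // hypothetical sizes
}

// Mirrors the new logic: tag the op as padded, then return early if the padded
// scratch requirement already fits the per-op LDS budget.
bool fitsAfterForcingPadding(int64_t targetLDSSize) {
  // (real pass: cvtOp->setAttr(AttrSharedMemPadded, ...) happens here)
  return scratchBytes(/*usePadding=*/true) <= targetLDSSize;
}

int main() {
  std::cout << fitsAfterForcingPadding(64 * 1024) << "\n"; // 1: done, no decomposition needed
  std::cout << fitsAfterForcingPadding(32 * 1024) << "\n"; // 0: continue with layout decomposition
}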

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h

Lines changed: 2 additions & 0 deletions
@@ -76,6 +76,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                        StringRef name, StringRef value,
                        unsigned addressSpace) const;
 
+  bool isXpu() const override { return true; }
+
 protected:
   virtual bool isSupportedWarpReduceOp(Operation *op, unsigned numLanesToReduce,
                                        unsigned warpSize) const = 0;
