
Commit 10a9526

Revert "Revert "[AMD] Enable General Swizzling ConvertLayoutOp (#7482)""
This reverts commit 8a721f3.
1 parent fad8c9b commit 10a9526

4 files changed: 102 additions, 11 deletions


lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
@@ -279,9 +279,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

     // Try to use swizzling to implement the conversion
-    // HACK Remove once AMD tests pass for the swizzling path
-    if (targetInfo.isCuda() && succeeded(transferWithinBlockSwizzling(
-            op, adaptor.getSrc(), rewriter))) {
+    if (succeeded(
+            transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter))) {
       return success();
     }

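In effect, this hunk drops the targetInfo.isCuda() guard so the general swizzling path is attempted on AMD targets as well. A minimal sketch of the resulting control flow, under the assumption that a generic shared-memory transfer follows in the same pattern; lowerWithFallback and fallbackTransfer are placeholder names, not the actual upstream helpers:

// Sketch only: swizzling is now tried unconditionally, and the pattern falls
// back to the generic shared-memory transfer only when swizzling fails.
LogicalResult lowerWithFallback(ConvertLayoutOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) {
  assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
  if (succeeded(transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter)))
    return success(); // general swizzling produced the lowering
  return fallbackTransfer(op, adaptor, rewriter); // placeholder fallback path
}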
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --cse | FileCheck %s
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK: llvm.mlir.global external @global_smem
+  tt.func @convert_layout_general_swizzling(%arg0: tensor<64x64xf32, #blocked0>, %arg1: tensor<64x64x!tt.ptr<f32>, #blocked1>) {
+
+    // Verify that the following convert_layout uses the general swizzling path.
+
+    // CHECK: [[CST_128:%.*]] = llvm.mlir.constant(128 : i32) : i32
+
+    // Part of the offset computation generated by the applyLinearLayout function
+    // CHECK: [[SEL:%.*]] = llvm.select {{.*}}, {{.*}}, [[CST_128]]
+    // CHECK: [[OFFSET_0:%.*]] = llvm.xor {{.*}}, [[SEL]]
+    // CHECK: [[OFFSET_1:%.*]] = llvm.xor {{.*}}, [[OFFSET_0]] : i32
+
+    // Part of the offset computation generated by the lowerLdSt function after applyLinearLayout
+    // CHECK: [[OFFSET_2:%.*]] = llvm.xor [[OFFSET_1]], {{.*}} : i32
+    // CHECK: [[OFFSET_3:%.*]] = llvm.xor [[OFFSET_2]], {{.*}} : i32
+    // CHECK: [[OFFSET_4:%.*]] = llvm.add [[OFFSET_3]], {{.*}} : i32
+    // CHECK: llvm.getelementptr inbounds {{.*}}{{\[}}[[OFFSET_4]]{{\]}}
+
+    %0 = ttg.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked1>
+    tt.store %arg1, %0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: convert_layout_padding_swizzling
+  tt.func @convert_layout_padding_swizzling(%arg0: tensor<64x64xf32, #blocked0>, %arg1: tensor<64x64x!tt.ptr<f32>, #blocked1>) {
+
+    // Verify that the following convert_layout uses the padded path
+    // (see the getVecAddr lambda in the transferWithinBlockImpl function).
+
+    // CHECK-DAG: [[CST_0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-DAG: [[CST_5:%.*]] = llvm.mlir.constant(5 : i32) : i32
+    // CHECK-DAG: [[OFFSET_0:%.*]] = llvm.lshr {{.*}}, [[CST_5]] : i32
+    // CHECK: [[OFFSET_1:%.*]] = llvm.shl [[OFFSET_0]], [[CST_0]] : i32
+    // CHECK: [[OFFSET_2:%.*]] = llvm.add [[OFFSET_1]], {{.*}} : i32
+    // CHECK: llvm.getelementptr inbounds {{.*}}{{\[}}[[OFFSET_2]]{{\]}}
+
+    %0 = ttg.convert_layout %arg0 {amdgpu.use_padded_scratch_shmem} : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked1>
+    tt.store %arg1, %0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 36 additions & 0 deletions
@@ -1,9 +1,14 @@
+#include "Analysis/AMDGPUAllocation.h"
 #include "PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"

+using ::mlir::transferWithinBlockPadding;
 using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
+using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
 using ::mlir::triton::gpu::ConvertLayoutOp;
+using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::MemDescType;
 using ::triton::gpu::LinearEncodingAttr;

 namespace {
@@ -113,11 +118,42 @@ class ConvertLayoutOpMFMAToLinearConversion
 private:
   const TargetInfoBase &targetInfo;
 };
+
+struct ConvertLayoutForcedPadding
+    : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
+
+  explicit ConvertLayoutForcedPadding(LLVMTypeConverter &typeConverter,
+                                      const TargetInfoBase &targetInfo,
+                                      PatternBenefit benefit)
+      : ConvertOpToLLVMPattern<ConvertLayoutOp>(typeConverter, benefit),
+        targetInfo(targetInfo) {}
+
+  LogicalResult
+  matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op->hasAttr(mlir::triton::AMD::AttrSharedMemPadded))
+      return failure();
+    auto srcType = op.getSrc().getType();
+    auto dstType = op.getType();
+    if (!cvtNeedsSharedMemory(srcType, dstType))
+      return failure();
+
+    auto result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
+                                             getTypeConverter(), rewriter);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+protected:
+  const TargetInfoBase &targetInfo;
+};
+
 } // namespace

 void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ConvertLayoutOpMFMAToLinearConversion>(typeConverter, targetInfo,
                                                       benefit);
+  patterns.add<ConvertLayoutForcedPadding>(typeConverter, targetInfo, benefit);
 }
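The new ConvertLayoutForcedPadding pattern fires only for conversions tagged with AttrSharedMemPadded; untagged ops keep going through the common swizzling lowering above. A minimal usage sketch follows; the forcePaddedScratch helper is hypothetical and only illustrates how a pass opts a single op into the padded path, mirroring what OptimizeLDSUsage does in the next file:

// Hypothetical helper (not part of this commit): tag one convert_layout so the
// ConvertLayoutForcedPadding pattern, rather than general swizzling, lowers it.
static void forcePaddedScratch(mlir::triton::gpu::ConvertLayoutOp cvtOp) {
  mlir::MLIRContext *ctx = cvtOp.getContext();
  cvtOp->setAttr(mlir::triton::AMD::AttrSharedMemPadded,
                 mlir::UnitAttr::get(ctx));
}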

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp

Lines changed: 12 additions & 8 deletions
@@ -94,15 +94,26 @@ class OptimizeAMDLDSUsage
     LDBG("Trying fit " << cvtOp << " into " << targetLDSSize << " bytes");
     OpBuilder builder(cvtOp);

+    auto ctx = builder.getContext();
     auto srcType = cvtOp.getSrc().getType();
     auto dstType = cvtOp.getType();

+    if (!cvtOp->hasAttr(triton::AMD::AttrSharedMemPadded)) {
+      auto emptyAttribute = UnitAttr::get(ctx);
+      // The padded conversion seems more friendly to this optimization,
+      // so use it instead of general swizzling.
+      cvtOp->setAttr(triton::AMD::AttrSharedMemPadded, emptyAttribute);
+      // If the padded layout alone drops LDS usage below the target, we are done.
+      if (triton::AMD::getConvertLayoutScratchInBytes(
+              srcType, dstType, /*usePadding*/ true) <= targetLDSSize)
+        return;
+    }
+
     auto srcEnc =
         cast<triton::gpu::DistributedEncodingTrait>(srcType.getEncoding());
     auto dstEnc =
         cast<triton::gpu::DistributedEncodingTrait>(dstType.getEncoding());

-    auto ctx = srcEnc.getContext();
     auto rank = srcType.getRank();

     unsigned numWarps = triton::gpu::lookupNumWarps(cvtOp);
@@ -244,13 +255,6 @@ class OptimizeAMDLDSUsage
       LDSLimit = targetInfo.getSharedMemorySize();
     }

-    auto context = mod.getContext();
-    auto emptyAttribute = UnitAttr::get(context);
-    // TODO choose between padded and swizzled memory patterns
-    mod.walk([emptyAttribute](triton::gpu::ConvertLayoutOp op) -> void {
-      op->setAttr(mlir::triton::AMD::AttrSharedMemPadded, emptyAttribute);
-    });
-
     ModuleAllocation allocAnalysis(
         mod, mlir::triton::AMD::AMDAllocationAnalysisScratchSizeFn);
     if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
