Skip to content

Commit 8c418a2

Browse files
committed
PromoteShuffleToDPPPattern
1 parent 084fe21 commit 8c418a2

File tree

2 files changed

+224
-7
lines changed

2 files changed

+224
-7
lines changed

mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp

Lines changed: 146 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
15-
#include "mlir/Dialect/GPU/Transforms/Passes.h"
16-
1714
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1816
#include "mlir/Dialect/Arith/IR/Arith.h"
1917
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
19+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
2021
#include "mlir/IR/PatternMatch.h"
2122
#include <optional>
2223

@@ -85,7 +86,7 @@ struct PromoteShuffleToPermlanePattern
8586

8687
int64_t offsetValue = *offset;
8788
if (offsetValue != 16 && offsetValue != 32)
88-
return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
89+
return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");
8990

9091
Location loc = op.getLoc();
9192
Value res = amdgpu::PermlaneSwapOp::create(
@@ -96,13 +97,153 @@ struct PromoteShuffleToPermlanePattern
9697
}
9798
};
9899

100+
static Value getLaneId(RewriterBase &rewriter, Location loc) {
101+
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
102+
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
103+
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
104+
NamedAttribute noundef = rewriter.getNamedAttr(
105+
LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
106+
NamedAttribute lowRange = rewriter.getNamedAttr(
107+
LLVM::LLVMDialect::getRangeAttrName(),
108+
LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
109+
APInt(32, 32)));
110+
NamedAttribute highRange = rewriter.getNamedAttr(
111+
LLVM::LLVMDialect::getRangeAttrName(),
112+
LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
113+
APInt(32, 64)));
114+
Value mbcntLo = ROCDL::MbcntLoOp::create(
115+
rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
116+
/*res_attrs=*/
117+
rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
118+
Value laneId = ROCDL::MbcntHiOp::create(
119+
rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
120+
rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
121+
return laneId;
122+
}
123+
124+
/// Try to promote `gpu.shuffle` to `amdgpu.dpp`, width must be 64
125+
/// and offset must be a constant integer in the set {16, 32}.
126+
struct PromoteShuffleToDPPPattern : public OpRewritePattern<gpu::ShuffleOp> {
127+
using OpRewritePattern::OpRewritePattern;
128+
129+
LogicalResult matchAndRewrite(gpu::ShuffleOp op,
130+
PatternRewriter &rewriter) const override {
131+
std::optional<int64_t> width = getConstantIntValue(op.getWidth());
132+
if (!width)
133+
return rewriter.notifyMatchFailure(op,
134+
"width must be a constant integer");
135+
int64_t widthValue = *width;
136+
if (widthValue != 4 && widthValue != 8 && widthValue != 12 &&
137+
widthValue != 16 && widthValue != 32 && widthValue != 48 &&
138+
widthValue != 64)
139+
return rewriter.notifyMatchFailure(
140+
op, "width must be 4, 8, 12, 16, 32, 48 or 64");
141+
142+
std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
143+
if (!offset)
144+
return rewriter.notifyMatchFailure(op,
145+
"offset must be a constant integer");
146+
147+
int64_t offsetValue = *offset;
148+
Location loc = op.getLoc();
149+
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
150+
151+
amdgpu::DPPPerm kind;
152+
Attribute permAttr = rewriter.getUnitAttr();
153+
Value srcLane;
154+
Value dstLane;
155+
switch (op.getMode()) {
156+
case gpu::ShuffleMode::XOR: {
157+
if (offsetValue != 1 && offsetValue != 2)
158+
return rewriter.notifyMatchFailure(
159+
op, "xor shuffle mode is only supported for offsets of 1 or 2");
160+
kind = amdgpu::DPPPerm::quad_perm;
161+
srcLane = getLaneId(rewriter, loc);
162+
dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLane,
163+
op.getOffset());
164+
165+
if (offsetValue == 1)
166+
permAttr = rewriter.getI32ArrayAttr({1, 0, 3, 2});
167+
else if (offsetValue == 2)
168+
permAttr = rewriter.getI32ArrayAttr({2, 3, 0, 1});
169+
break;
170+
}
171+
case gpu::ShuffleMode::UP: {
172+
if (offsetValue != 1)
173+
return rewriter.notifyMatchFailure(
174+
op, "up shuffle mode is only supported for offset 1");
175+
kind = amdgpu::DPPPerm::wave_shr;
176+
srcLane = getLaneId(rewriter, loc);
177+
dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLane,
178+
op.getOffset());
179+
break;
180+
}
181+
case gpu::ShuffleMode::DOWN: {
182+
if (offsetValue != 1)
183+
return rewriter.notifyMatchFailure(
184+
op, "down shuffle mode is only supported for offset 1");
185+
kind = amdgpu::DPPPerm::wave_shl;
186+
srcLane = getLaneId(rewriter, loc);
187+
dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLane,
188+
op.getOffset());
189+
break;
190+
}
191+
case gpu::ShuffleMode::IDX:
192+
return rewriter.notifyMatchFailure(op,
193+
"idx shuffle mode is not supported");
194+
}
195+
196+
unsigned bankMask = 0xF;
197+
if (widthValue == 4)
198+
bankMask = 0x1;
199+
else if (widthValue == 8)
200+
bankMask = 0x3;
201+
else if (widthValue == 12)
202+
bankMask = 0x7;
203+
204+
unsigned rowMask = 0xF;
205+
if (widthValue == 16)
206+
rowMask = 0x1;
207+
else if (widthValue == 32)
208+
rowMask = 0x3;
209+
else if (widthValue == 48)
210+
rowMask = 0x7;
211+
212+
constexpr bool boundCtrl = false;
213+
214+
Value negwidth =
215+
arith::ConstantIntOp::create(rewriter, loc, int32Type, -widthValue);
216+
Value add =
217+
arith::AddIOp::create(rewriter, loc, int32Type, srcLane, op.getWidth());
218+
Value widthOrZeroIfOutside =
219+
arith::AndIOp::create(rewriter, loc, int32Type, add, negwidth);
220+
Value isActiveSrcLane =
221+
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, dstLane,
222+
widthOrZeroIfOutside);
223+
224+
Value dpp = amdgpu::DPPOp::create(rewriter, loc, op.getResult(0).getType(),
225+
op.getValue(), op.getValue(), kind,
226+
permAttr, rowMask, bankMask, boundCtrl);
227+
Value poison =
228+
LLVM::PoisonOp::create(rewriter, loc, op.getResult(0).getType());
229+
230+
Value selectResult =
231+
arith::SelectOp::create(rewriter, loc, isActiveSrcLane, dpp, poison);
232+
233+
rewriter.replaceOp(op, {selectResult, isActiveSrcLane});
234+
return success();
235+
}
236+
};
237+
99238
} // namespace
100239

101240
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
102241
RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
103242
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
104243
/*benefit*/ 1);
244+
patterns.add<PromoteShuffleToDPPPattern>(patterns.getContext(),
245+
/*benefit*/ 2);
105246
if (maybeChipset && *maybeChipset >= kGfx950)
106247
patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
107-
/*benefit*/ 2);
248+
/*benefit*/ 3);
108249
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -735,13 +735,18 @@ gpu.module @test_module {
735735
}
736736

737737
// CHECK-LABEL: func @gpu_shuffle_promote()
738-
func.func @gpu_shuffle_promote() -> (f32, f32, f32) {
738+
func.func @gpu_shuffle_promote() -> (f32, f32, f32, f32, f32) {
739+
// CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
740+
// CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
739741
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
740742
%arg0 = arith.constant 1.0 : f32
741743
%arg1 = arith.constant 4 : i32
742744
%arg2 = arith.constant 16 : i32
743745
%arg3 = arith.constant 32 : i32
746+
// CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
744747
%arg4 = arith.constant 64 : i32
748+
// CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
749+
%arg5 = arith.constant 1 : i32
745750
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
746751
// CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
747752
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
@@ -757,7 +762,78 @@ gpu.module @test_module {
757762
// CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
758763
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
759764
%shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
760-
func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
765+
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
766+
// CHECK: %[[#SUB:]] = llvm.sub %[[#LANE_ID]], %[[#C1]] : i32
767+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
768+
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
769+
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#SUB]], %[[#AND]] : i32
770+
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 312, 15, 15, false : f32
771+
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
772+
%shflu, %predu = gpu.shuffle up %arg0, %arg5, %arg4 : f32
773+
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
774+
// CHECK: %[[#OP:]] = llvm.add %[[#LANE_ID]], %[[#C1]] : i32
775+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
776+
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
777+
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#OP]], %[[#AND]] : i32
778+
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 304, 15, 15, false : f32
779+
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
780+
%shfld, %predd = gpu.shuffle down %arg0, %arg5, %arg4 : f32
781+
func.return %shfl1, %shfl2, %shfl3, %shflu, %shfld : f32, f32, f32, f32, f32
782+
}
783+
784+
// CHECK-LABEL: func @gpu_butterfly_shuffle()
785+
func.func @gpu_butterfly_shuffle() -> (f32, f32, f32, f32, f32, f32) {
786+
// CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
787+
// CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
788+
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
789+
%arg0 = arith.constant 1.0 : f32
790+
// CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
791+
%c1 = arith.constant 1 : i32
792+
// CHECK: %[[#C2:]] = llvm.mlir.constant(2 : i32) : i32
793+
%c2 = arith.constant 2 : i32
794+
%c4 = arith.constant 4 : i32
795+
%c8 = arith.constant 8 : i32
796+
%c16 = arith.constant 16 : i32
797+
%c32 = arith.constant 32 : i32
798+
// CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
799+
%c64 = arith.constant 64 : i32
800+
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
801+
// CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C1]] : i32
802+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
803+
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
804+
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
805+
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 177, 15, 15, false : f32
806+
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
807+
%shfl1, %pred1 = gpu.shuffle xor %arg0, %c1, %c64 : f32
808+
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
809+
// CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C2]] : i32
810+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
811+
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
812+
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
813+
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 78, 15, 15, false : f32
814+
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
815+
%shfl2, %pred2 = gpu.shuffle xor %arg0, %c2, %c64 : f32
816+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
817+
// CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
818+
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
819+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
820+
%shfl3, %pred3 = gpu.shuffle xor %arg0, %c4, %c64 : f32
821+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
822+
// CHECK: %[[#MASK:]] = llvm.mlir.constant(8223 : i32) : i32
823+
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
824+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
825+
%shfl4, %pred4 = gpu.shuffle xor %arg0, %c8, %c64 : f32
826+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
827+
// CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
828+
// CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
829+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
830+
%shfl5, %pred5 = gpu.shuffle xor %arg0, %c16, %c64 : f32
831+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
832+
// CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
833+
// CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
834+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
835+
%shfl6, %pred6 = gpu.shuffle xor %arg0, %c32, %c64 : f32
836+
func.return %shfl1, %shfl2, %shfl3, %shfl4, %shfl5, %shfl6 : f32, f32, f32, f32, f32, f32
761837
}
762838

763839
// CHECK-LABEL: func @gpu_shuffle_vec

0 commit comments

Comments
 (0)