Skip to content

Commit 65a9065

Browse files
committed
PromoteShuffleToDPPPattern
1 parent a24a754 commit 65a9065

File tree

2 files changed

+228
-7
lines changed

2 files changed

+228
-7
lines changed

mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp

Lines changed: 144 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
15-
#include "mlir/Dialect/GPU/Transforms/Passes.h"
16-
1714
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1816
#include "mlir/Dialect/Arith/IR/Arith.h"
1917
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
19+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
2021
#include "mlir/IR/PatternMatch.h"
2122
#include <optional>
2223

@@ -85,7 +86,7 @@ struct PromoteShuffleToPermlanePattern
8586

8687
int64_t offsetValue = *offset;
8788
if (offsetValue != 16 && offsetValue != 32)
88-
return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
89+
return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");
8990

9091
Location loc = op.getLoc();
9192
Value res = amdgpu::PermlaneSwapOp::create(
@@ -96,13 +97,151 @@ struct PromoteShuffleToPermlanePattern
9697
}
9798
};
9899

100+
/// Emits the ops that compute the lane id of the current thread and returns
/// the resulting i32 value.
///
/// The value is built as `rocdl.mbcnt.hi(-1, rocdl.mbcnt.lo(-1, 0))`. Both
/// results are annotated with `noundef` and a range attribute whose bounds are
/// (0, 32) for the low half and (0, 64) for the high half.
///
/// NOTE(review): this relies on mbcnt.lo/hi counting the mask bits that belong
/// to lanes below the current one, so an all-ones mask yields the lane index
/// -- confirm against the ROCDL/AMDGPU intrinsic documentation.
/// NOTE(review): the upper range bounds are presumably exclusive, as in LLVM's
/// ConstantRange -- confirm.
static Value getLaneId(RewriterBase &rewriter, Location loc) {
101+
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
102+
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
103+
// All-ones 32-bit mask passed to both mbcnt ops.
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
104+
// Result attributes shared by both mbcnt ops.
NamedAttribute noundef = {LLVM::LLVMDialect::getNoUndefAttrName(),
105+
rewriter.getUnitAttr()};
106+
// Range attached to the mbcnt.lo result: bounds 0 and 32.
NamedAttribute lowRange = {LLVM::LLVMDialect::getRangeAttrName(),
107+
LLVM::ConstantRangeAttr::get(rewriter.getContext(),
108+
APInt::getZero(32),
109+
APInt(32, 32))};
110+
// Range attached to the mbcnt.hi result (the returned lane id): bounds 0
// and 64.
NamedAttribute highRange = {
111+
LLVM::LLVMDialect::getRangeAttrName(),
112+
LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
113+
APInt(32, 64))};
114+
// Low half: starts the count from the constant 0.
Value mbcntLo = ROCDL::MbcntLoOp::create(
115+
rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
116+
/*res_attrs=*/
117+
rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
118+
// High half: continues from the mbcnt.lo result.
Value laneId = ROCDL::MbcntHiOp::create(
119+
rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
120+
rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
121+
return laneId;
122+
}
123+
124+
/// Try to promote `gpu.shuffle` to `amdgpu.dpp`. The width must be a constant
125+
/// integer in the set {4, 8, 12, 16, 32, 48, 64} and the offset must be a
/// constant integer: 1 or 2 for `xor` mode, exactly 1 for `up` and `down`
/// mode. `idx` mode is never matched.
///
/// On success the shuffle is replaced by two values: a select between the DPP
/// result and poison (the shuffled value), and the i1 condition of that select
/// (the validity flag).
126+
struct PromoteShuffleToDPPPattern : public OpRewritePattern<gpu::ShuffleOp> {
127+
using OpRewritePattern::OpRewritePattern;
128+
129+
LogicalResult matchAndRewrite(gpu::ShuffleOp op,
130+
PatternRewriter &rewriter) const override {
131+
// The width has to be a compile-time constant from a fixed list because it
// selects the DPP row/bank masks below.
std::optional<int64_t> width = getConstantIntValue(op.getWidth());
132+
if (!width)
133+
return rewriter.notifyMatchFailure(op,
134+
"width must be a constant integer");
135+
int64_t widthValue = *width;
136+
if (!llvm::is_contained({4, 8, 12, 16, 32, 48, 64}, widthValue))
137+
return rewriter.notifyMatchFailure(
138+
op, "width must be 4, 8, 12, 16, 32, 48 or 64");
139+
140+
// The offset has to be constant as well; the accepted values depend on the
// shuffle mode and are checked in the switch below.
std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
141+
if (!offset)
142+
return rewriter.notifyMatchFailure(op,
143+
"offset must be a constant integer");
144+
145+
int64_t offsetValue = *offset;
146+
Location loc = op.getLoc();
147+
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
148+
149+
// Per-mode selection of the DPP permutation kind, its optional attribute,
// and the lane index the shuffle refers to (`dstLane`). Every case either
// sets `kind`, `srcLane` and `dstLane` or returns a match failure.
amdgpu::DPPPerm kind;
150+
Attribute permAttr = rewriter.getUnitAttr();
151+
Value srcLane;
152+
Value dstLane;
153+
switch (op.getMode()) {
154+
case gpu::ShuffleMode::XOR: {
155+
if (offsetValue != 1 && offsetValue != 2)
156+
return rewriter.notifyMatchFailure(
157+
op, "xor shuffle mode is only supported for offsets of 1 or 2");
158+
kind = amdgpu::DPPPerm::quad_perm;
159+
srcLane = getLaneId(rewriter, loc);
160+
dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLane,
161+
op.getOffset());
162+
163+
// Entry i of each array equals i ^ offset, i.e. the xor pattern applied
// to the four indices 0..3.
if (offsetValue == 1)
164+
permAttr = rewriter.getI32ArrayAttr({1, 0, 3, 2});
165+
else if (offsetValue == 2)
166+
permAttr = rewriter.getI32ArrayAttr({2, 3, 0, 1});
167+
break;
168+
}
169+
case gpu::ShuffleMode::UP: {
170+
if (offsetValue != 1)
171+
return rewriter.notifyMatchFailure(
172+
op, "up shuffle mode is only supported for offset 1");
173+
kind = amdgpu::DPPPerm::wave_shr;
174+
srcLane = getLaneId(rewriter, loc);
175+
dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLane,
176+
op.getOffset());
177+
break;
178+
}
179+
case gpu::ShuffleMode::DOWN: {
180+
if (offsetValue != 1)
181+
return rewriter.notifyMatchFailure(
182+
op, "down shuffle mode is only supported for offset 1");
183+
kind = amdgpu::DPPPerm::wave_shl;
184+
srcLane = getLaneId(rewriter, loc);
185+
dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLane,
186+
op.getOffset());
187+
break;
188+
}
189+
case gpu::ShuffleMode::IDX:
190+
return rewriter.notifyMatchFailure(op,
191+
"idx shuffle mode is not supported");
192+
}
193+
194+
// Widths below 16 narrow the bank mask; every other width keeps 0xF.
// NOTE(review): these masks presumably limit which lanes the DPP op
// writes -- confirm that this matches the intended `width` semantics of
// gpu.shuffle.
unsigned bankMask = 0xF;
195+
if (widthValue == 4)
196+
bankMask = 0x1;
197+
else if (widthValue == 8)
198+
bankMask = 0x3;
199+
else if (widthValue == 12)
200+
bankMask = 0x7;
201+
202+
// Widths 16, 32 and 48 narrow the row mask; every other width keeps 0xF.
unsigned rowMask = 0xF;
203+
if (widthValue == 16)
204+
rowMask = 0x1;
205+
else if (widthValue == 32)
206+
rowMask = 0x3;
207+
else if (widthValue == 48)
208+
rowMask = 0x7;
209+
210+
constexpr bool boundCtrl = false;
211+
212+
// Validity check: dstLane < ((srcLane + width) & -width), signed.
// For a power-of-two width the masked sum is srcLane + width rounded down
// to a multiple of width.
// NOTE(review): for the accepted widths 12 and 48, -width is not a
// contiguous bit mask, so the rounding property does not hold -- confirm
// the intended behavior.
// NOTE(review): a negative dstLane (e.g. `up` from lane 0 gives -1) passes
// the signed `slt` comparison -- confirm this is intended.
Value negwidth =
213+
arith::ConstantIntOp::create(rewriter, loc, int32Type, -widthValue);
214+
Value add =
215+
arith::AddIOp::create(rewriter, loc, int32Type, srcLane, op.getWidth());
216+
Value widthOrZeroIfOutside =
217+
arith::AndIOp::create(rewriter, loc, int32Type, add, negwidth);
218+
Value isActiveSrcLane =
219+
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, dstLane,
220+
widthOrZeroIfOutside);
221+
222+
// The shuffled value is passed to the DPP op twice, as both of its value
// operands.
Value dpp = amdgpu::DPPOp::create(rewriter, loc, op.getResult(0).getType(),
223+
op.getValue(), op.getValue(), kind,
224+
permAttr, rowMask, bankMask, boundCtrl);
225+
Value poison =
226+
LLVM::PoisonOp::create(rewriter, loc, op.getResult(0).getType());
227+
228+
// Lanes that fail the validity check receive poison instead of the DPP
// result.
Value selectResult =
229+
arith::SelectOp::create(rewriter, loc, isActiveSrcLane, dpp, poison);
230+
231+
// Replace both results of the shuffle: the value and the validity flag.
rewriter.replaceOp(op, {selectResult, isActiveSrcLane});
232+
return success();
233+
}
234+
};
235+
99236
} // namespace
100237

101238
/// Adds the `gpu.shuffle` promotion patterns to `patterns`.
///
/// The swizzle and DPP patterns are always registered. The permlane pattern
/// is registered only when `maybeChipset` holds a chipset that compares
/// `>= kGfx950`. Each pattern is given a distinct benefit via the
/// `/*benefit*/` argument.
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
102239
RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
103240
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
104241
/*benefit*/ 1);
242+
patterns.add<PromoteShuffleToDPPPattern>(patterns.getContext(),
243+
/*benefit*/ 2);
105244
// Chipset-gated: skipped when no chipset is given or it is older than
// kGfx950.
if (maybeChipset && *maybeChipset >= kGfx950)
106245
patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
107-
/*benefit*/ 2);
246+
/*benefit*/ 3);
108247
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -735,13 +735,18 @@ gpu.module @test_module {
735735
}
736736

737737
// CHECK-LABEL: func @gpu_shuffle_promote()
738-
func.func @gpu_shuffle_promote() -> (f32, f32, f32) {
738+
func.func @gpu_shuffle_promote() -> (f32, f32, f32, f32, f32) {
739+
// CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
740+
// CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
739741
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
740742
%arg0 = arith.constant 1.0 : f32
741743
%arg1 = arith.constant 4 : i32
742744
%arg2 = arith.constant 16 : i32
743745
%arg3 = arith.constant 32 : i32
746+
// CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
744747
%arg4 = arith.constant 64 : i32
748+
// CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
749+
%arg5 = arith.constant 1 : i32
745750
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
746751
// CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
747752
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
@@ -763,7 +768,84 @@ gpu.module @test_module {
763768
// CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
764769
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
765770
%shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
766-
func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
771+
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
772+
// CHECK: %[[#SUB:]] = llvm.sub %[[#LANE_ID]], %[[#C1]] : i32
773+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
774+
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
775+
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#SUB]], %[[#AND]] : i32
776+
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 312, 15, 15, false : f32
777+
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
778+
%shflu, %predu = gpu.shuffle up %arg0, %arg5, %arg4 : f32
779+
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
780+
// CHECK: %[[#OP:]] = llvm.add %[[#LANE_ID]], %[[#C1]] : i32
781+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
782+
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
783+
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#OP]], %[[#AND]] : i32
784+
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 304, 15, 15, false : f32
785+
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
786+
%shfld, %predd = gpu.shuffle down %arg0, %arg5, %arg4 : f32
787+
func.return %shfl1, %shfl2, %shfl3, %shflu, %shfld : f32, f32, f32, f32, f32
788+
}
789+
790+
// Exercises `gpu.shuffle xor` with width 64 and offsets 1, 2, 4, 8, 16 and 32.
// The FileCheck expectations below require:
//   - offsets 1 and 2:  rocdl.update.dpp, guarded by a lane-validity select
//                       against poison;
//   - offsets 4 and 8:  rocdl.ds_swizzle;
//   - offsets 16 and 32: rocdl.permlane16.swap / rocdl.permlane32.swap.
// CHECK-LABEL: func @gpu_butterfly_shuffle()
791+
func.func @gpu_butterfly_shuffle() -> (f32, f32, f32, f32, f32, f32) {
792+
// CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
793+
// CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
794+
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
795+
%arg0 = arith.constant 1.0 : f32
796+
// CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
797+
%c1 = arith.constant 1 : i32
798+
// CHECK: %[[#C2:]] = llvm.mlir.constant(2 : i32) : i32
799+
%c2 = arith.constant 2 : i32
800+
%c4 = arith.constant 4 : i32
801+
%c8 = arith.constant 8 : i32
802+
%c16 = arith.constant 16 : i32
803+
%c32 = arith.constant 32 : i32
804+
// CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
805+
%c64 = arith.constant 64 : i32
806+
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
807+
// CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C1]] : i32
808+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
809+
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
810+
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
811+
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 177, 15, 15, false : f32
812+
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
813+
%shfl1, %pred1 = gpu.shuffle xor %arg0, %c1, %c64 : f32
814+
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
815+
// CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C2]] : i32
816+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
817+
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
818+
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
819+
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 78, 15, 15, false : f32
820+
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
821+
%shfl2, %pred2 = gpu.shuffle xor %arg0, %c2, %c64 : f32
822+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
823+
// CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
824+
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
825+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
826+
%shfl3, %pred3 = gpu.shuffle xor %arg0, %c4, %c64 : f32
827+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
828+
// CHECK: %[[#MASK:]] = llvm.mlir.constant(8223 : i32) : i32
829+
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
830+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
831+
%shfl4, %pred4 = gpu.shuffle xor %arg0, %c8, %c64 : f32
832+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
833+
// CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
834+
// CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
835+
// CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
836+
// CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
837+
// CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
838+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
839+
%shfl5, %pred5 = gpu.shuffle xor %arg0, %c16, %c64 : f32
840+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
841+
// CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
842+
// CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
843+
// CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
844+
// CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
845+
// CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
846+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
847+
%shfl6, %pred6 = gpu.shuffle xor %arg0, %c32, %c64 : f32
848+
func.return %shfl1, %shfl2, %shfl3, %shfl4, %shfl5, %shfl6 : f32, f32, f32, f32, f32, f32
767849
}
768850

769851
// CHECK-LABEL: func @gpu_shuffle_vec

0 commit comments

Comments
 (0)