1111//
1212// ===----------------------------------------------------------------------===//
1313
14- #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
15- #include " mlir/Dialect/GPU/Transforms/Passes.h"
16-
1714#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15+ #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
1816#include " mlir/Dialect/Arith/IR/Arith.h"
1917#include " mlir/Dialect/GPU/IR/GPUDialect.h"
18+ #include " mlir/Dialect/GPU/Transforms/Passes.h"
19+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
20+ #include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
2021#include " mlir/IR/PatternMatch.h"
2122#include < optional>
2223
@@ -85,7 +86,7 @@ struct PromoteShuffleToPermlanePattern
8586
8687 int64_t offsetValue = *offset;
8788 if (offsetValue != 16 && offsetValue != 32 )
88- return rewriter.notifyMatchFailure (op, " offset must be either 15 or 31 " );
89+ return rewriter.notifyMatchFailure (op, " offset must be either 16 or 32 " );
8990
9091 Location loc = op.getLoc ();
9192 Value res = amdgpu::PermlaneSwapOp::create (
@@ -96,13 +97,153 @@ struct PromoteShuffleToPermlanePattern
9697 }
9798};
9899
100+ static Value getLaneId (RewriterBase &rewriter, Location loc) {
101+ auto int32Type = IntegerType::get (rewriter.getContext (), 32 );
102+ Value zero = arith::ConstantIntOp::create (rewriter, loc, 0 , 32 );
103+ Value minus1 = arith::ConstantIntOp::create (rewriter, loc, -1 , 32 );
104+ NamedAttribute noundef = rewriter.getNamedAttr (
105+ LLVM::LLVMDialect::getNoUndefAttrName (), rewriter.getUnitAttr ());
106+ NamedAttribute lowRange = rewriter.getNamedAttr (
107+ LLVM::LLVMDialect::getRangeAttrName (),
108+ LLVM::ConstantRangeAttr::get (rewriter.getContext (), APInt::getZero (32 ),
109+ APInt (32 , 32 )));
110+ NamedAttribute highRange = rewriter.getNamedAttr (
111+ LLVM::LLVMDialect::getRangeAttrName (),
112+ LLVM::ConstantRangeAttr::get (rewriter.getContext (), APInt::getZero (32 ),
113+ APInt (32 , 64 )));
114+ Value mbcntLo = ROCDL::MbcntLoOp::create (
115+ rewriter, loc, int32Type, minus1, zero, /* arg_attrs=*/ {},
116+ /* res_attrs=*/
117+ rewriter.getArrayAttr (rewriter.getDictionaryAttr ({noundef, lowRange})));
118+ Value laneId = ROCDL::MbcntHiOp::create (
119+ rewriter, loc, int32Type, minus1, mbcntLo, /* arg_attrs=*/ {},
120+ rewriter.getArrayAttr (rewriter.getDictionaryAttr ({noundef, highRange})));
121+ return laneId;
122+ }
123+
124+ // / Try to promote `gpu.shuffle` to `amdgpu.dpp`, width must be 64
125+ // / and offset must be a constant integer in the set {16, 32}.
126+ struct PromoteShuffleToDPPPattern : public OpRewritePattern <gpu::ShuffleOp> {
127+ using OpRewritePattern::OpRewritePattern;
128+
129+ LogicalResult matchAndRewrite (gpu::ShuffleOp op,
130+ PatternRewriter &rewriter) const override {
131+ std::optional<int64_t > width = getConstantIntValue (op.getWidth ());
132+ if (!width)
133+ return rewriter.notifyMatchFailure (op,
134+ " width must be a constant integer" );
135+ int64_t widthValue = *width;
136+ if (widthValue != 4 && widthValue != 8 && widthValue != 12 &&
137+ widthValue != 16 && widthValue != 32 && widthValue != 48 &&
138+ widthValue != 64 )
139+ return rewriter.notifyMatchFailure (
140+ op, " width must be 4, 8, 12, 16, 32, 48 or 64" );
141+
142+ std::optional<int64_t > offset = getConstantIntValue (op.getOffset ());
143+ if (!offset)
144+ return rewriter.notifyMatchFailure (op,
145+ " offset must be a constant integer" );
146+
147+ int64_t offsetValue = *offset;
148+ Location loc = op.getLoc ();
149+ auto int32Type = IntegerType::get (rewriter.getContext (), 32 );
150+
151+ amdgpu::DPPPerm kind;
152+ Attribute permAttr = rewriter.getUnitAttr ();
153+ Value srcLane;
154+ Value dstLane;
155+ switch (op.getMode ()) {
156+ case gpu::ShuffleMode::XOR: {
157+ if (offsetValue != 1 && offsetValue != 2 )
158+ return rewriter.notifyMatchFailure (
159+ op, " xor shuffle mode is only supported for offsets of 1 or 2" );
160+ kind = amdgpu::DPPPerm::quad_perm;
161+ srcLane = getLaneId (rewriter, loc);
162+ dstLane = LLVM::XOrOp::create (rewriter, loc, int32Type, srcLane,
163+ op.getOffset ());
164+
165+ if (offsetValue == 1 )
166+ permAttr = rewriter.getI32ArrayAttr ({1 , 0 , 3 , 2 });
167+ else if (offsetValue == 2 )
168+ permAttr = rewriter.getI32ArrayAttr ({2 , 3 , 0 , 1 });
169+ break ;
170+ }
171+ case gpu::ShuffleMode::UP: {
172+ if (offsetValue != 1 )
173+ return rewriter.notifyMatchFailure (
174+ op, " up shuffle mode is only supported for offset 1" );
175+ kind = amdgpu::DPPPerm::wave_shr;
176+ srcLane = getLaneId (rewriter, loc);
177+ dstLane = LLVM::SubOp::create (rewriter, loc, int32Type, srcLane,
178+ op.getOffset ());
179+ break ;
180+ }
181+ case gpu::ShuffleMode::DOWN: {
182+ if (offsetValue != 1 )
183+ return rewriter.notifyMatchFailure (
184+ op, " down shuffle mode is only supported for offset 1" );
185+ kind = amdgpu::DPPPerm::wave_shl;
186+ srcLane = getLaneId (rewriter, loc);
187+ dstLane = LLVM::AddOp::create (rewriter, loc, int32Type, srcLane,
188+ op.getOffset ());
189+ break ;
190+ }
191+ case gpu::ShuffleMode::IDX:
192+ return rewriter.notifyMatchFailure (op,
193+ " idx shuffle mode is not supported" );
194+ }
195+
196+ unsigned bankMask = 0xF ;
197+ if (widthValue == 4 )
198+ bankMask = 0x1 ;
199+ else if (widthValue == 8 )
200+ bankMask = 0x3 ;
201+ else if (widthValue == 12 )
202+ bankMask = 0x7 ;
203+
204+ unsigned rowMask = 0xF ;
205+ if (widthValue == 16 )
206+ rowMask = 0x1 ;
207+ else if (widthValue == 32 )
208+ rowMask = 0x3 ;
209+ else if (widthValue == 48 )
210+ rowMask = 0x7 ;
211+
212+ constexpr bool boundCtrl = false ;
213+
214+ Value negwidth =
215+ arith::ConstantIntOp::create (rewriter, loc, int32Type, -widthValue);
216+ Value add =
217+ arith::AddIOp::create (rewriter, loc, int32Type, srcLane, op.getWidth ());
218+ Value widthOrZeroIfOutside =
219+ arith::AndIOp::create (rewriter, loc, int32Type, add, negwidth);
220+ Value isActiveSrcLane =
221+ arith::CmpIOp::create (rewriter, loc, arith::CmpIPredicate::slt, dstLane,
222+ widthOrZeroIfOutside);
223+
224+ Value dpp = amdgpu::DPPOp::create (rewriter, loc, op.getResult (0 ).getType (),
225+ op.getValue (), op.getValue (), kind,
226+ permAttr, rowMask, bankMask, boundCtrl);
227+ Value poison =
228+ LLVM::PoisonOp::create (rewriter, loc, op.getResult (0 ).getType ());
229+
230+ Value selectResult =
231+ arith::SelectOp::create (rewriter, loc, isActiveSrcLane, dpp, poison);
232+
233+ rewriter.replaceOp (op, {selectResult, isActiveSrcLane});
234+ return success ();
235+ }
236+ };
237+
99238} // namespace
100239
101240void mlir::populateGpuPromoteShuffleToAMDGPUPatterns (
102241 RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
103242 patterns.add <PromoteShuffleToSwizzlePattern>(patterns.getContext (),
104243 /* benefit*/ 1 );
244+ patterns.add <PromoteShuffleToDPPPattern>(patterns.getContext (),
245+ /* benefit*/ 2 );
105246 if (maybeChipset && *maybeChipset >= kGfx950 )
106247 patterns.add <PromoteShuffleToPermlanePattern>(patterns.getContext (),
107- /* benefit*/ 2 );
248+ /* benefit*/ 3 );
108249}
0 commit comments