@@ -136,9 +136,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
136136 matchAndRewrite (gpu::ShuffleOp op, OpAdaptor adaptor,
137137 ConversionPatternRewriter &rewriter) const override {
138138 Location loc = op->getLoc ();
139+ Value initShflValue = adaptor.getValue ();
140+ Type shflType = initShflValue.getType ();
139141 // TODO: Add support for non 32-bit shuffle values.
140- if (adaptor.getValue ().getType ().getIntOrFloatBitWidth () != 32 )
141- return failure ();
142+ if (!shflType.isIntOrFloat () || shflType.getIntOrFloatBitWidth () != 32 )
143+ return rewriter.notifyMatchFailure (
144+ op, " only 32-bit int/float types are supported" );
145+
142146 const unsigned indexBitwidth = getTypeConverter ()->getIndexTypeBitwidth ();
143147 Value srcLaneId = getLaneId (rewriter, loc, indexBitwidth);
144148
@@ -175,16 +179,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
175179 Value two = rewriter.create <LLVM::ConstantOp>(loc, int32Type, 2 );
176180 Value dwordAlignedDstLane =
177181 rewriter.create <LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
178- Value initShflValue = adaptor.getValue ();
179- if (adaptor.getValue ().getType ().isF32 ()) {
182+ if (shflType.isF32 ()) {
180183 initShflValue =
181184 rewriter.create <LLVM::BitcastOp>(loc, int32Type, initShflValue);
182185 }
183186 Value shflValue = rewriter.create <ROCDL::DsBpermuteOp>(
184187 loc, int32Type, dwordAlignedDstLane, initShflValue);
185- if (adaptor.getValue ().getType ().isF32 ()) {
186- shflValue = rewriter.create <LLVM::BitcastOp>(
187- loc, adaptor.getValue ().getType (), shflValue);
188+ if (shflType.isF32 ()) {
189+ shflValue = rewriter.create <LLVM::BitcastOp>(loc, shflType, shflValue);
188190 }
189191 rewriter.replaceOp (op, {shflValue, isActiveSrcLane});
190192 return success ();
0 commit comments