@@ -262,12 +262,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
262262 .Default ([](auto ) { return std::nullopt ; });
263263 }
264264
265- static std::optional<std::string> getFuncName (gpu::ShuffleOp op) {
266- StringRef baseName = getBaseName (op.getMode ());
267- std::optional<StringRef> typeMangling = getTypeMangling (op.getType (0 ));
265+ static std::optional<std::string> getFuncName (gpu::ShuffleMode mode,
266+ Type type) {
267+ StringRef baseName = getBaseName (mode);
268+ std::optional<StringRef> typeMangling = getTypeMangling (type);
268269 if (!typeMangling)
269270 return std::nullopt ;
270- return llvm::formatv (" _Z{0}{1}{2 }" , baseName.size (), baseName,
271+ return llvm::formatv (" _Z{}{}{ }" , baseName.size (), baseName,
271272 typeMangling.value ());
272273 }
273274
@@ -286,33 +287,70 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
286287 val == getSubgroupSize (op);
287288 }
288289
290+ static Value bitcastOrExtBeforeShuffle (Value oldVal, Location loc,
291+ ConversionPatternRewriter &rewriter) {
292+ return TypeSwitch<Type, Value>(oldVal.getType ())
293+ .Case ([&](BFloat16Type) {
294+ return rewriter.create <LLVM::BitcastOp>(loc, rewriter.getI16Type (),
295+ oldVal);
296+ })
297+ .Case ([&](IntegerType intTy) -> Value {
298+ if (intTy.getWidth () == 1 )
299+ return rewriter.create <LLVM::ZExtOp>(loc, rewriter.getI8Type (),
300+ oldVal);
301+ return oldVal;
302+ })
303+ .Default (oldVal);
304+ }
305+
306+ static Value bitcastOrTruncAfterShuffle (Value oldVal, Type newTy,
307+ Location loc,
308+ ConversionPatternRewriter &rewriter) {
309+ return TypeSwitch<Type, Value>(newTy)
310+ .Case ([&](BFloat16Type) {
311+ return rewriter.create <LLVM::BitcastOp>(loc, newTy, oldVal);
312+ })
313+ .Case ([&](IntegerType intTy) -> Value {
314+ if (intTy.getWidth () == 1 )
315+ return rewriter.create <LLVM::TruncOp>(loc, newTy, oldVal);
316+ return oldVal;
317+ })
318+ .Default (oldVal);
319+ }
320+
289321 LogicalResult
290322 matchAndRewrite (gpu::ShuffleOp op, OpAdaptor adaptor,
291323 ConversionPatternRewriter &rewriter) const final {
292324 if (!hasValidWidth (op))
293325 return rewriter.notifyMatchFailure (
294326 op, " shuffle width and subgroup size mismatch" );
295327
296- std::optional<std::string> funcName = getFuncName (op);
328+ Location loc = op->getLoc ();
329+ Value inValue =
330+ bitcastOrExtBeforeShuffle (adaptor.getValue (), loc, rewriter);
331+ std::optional<std::string> funcName =
332+ getFuncName (op.getMode (), inValue.getType ());
297333 if (!funcName)
298334 return rewriter.notifyMatchFailure (op, " unsupported value type" );
299335
300336 Operation *moduleOp = op->getParentWithTrait <OpTrait::SymbolTable>();
301337 assert (moduleOp && " Expecting module" );
302- Type valueType = adaptor. getValue () .getType ();
338+ Type valueType = inValue .getType ();
303339 Type offsetType = adaptor.getOffset ().getType ();
304340 Type resultType = valueType;
305341 LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn (
306342 moduleOp, funcName.value (), {valueType, offsetType}, resultType,
307343 /* isMemNone=*/ false , /* isConvergent=*/ true );
308344
309- Location loc = op->getLoc ();
310- std::array<Value, 2 > args{adaptor.getValue (), adaptor.getOffset ()};
345+ std::array<Value, 2 > args{inValue, adaptor.getOffset ()};
311346 Value result =
312347 createSPIRVBuiltinCall (loc, rewriter, func, args).getResult ();
348+ Value resultOrConversion =
349+ bitcastOrTruncAfterShuffle (result, op.getType (0 ), loc, rewriter);
350+
313351 Value trueVal =
314352 rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI1Type (), true );
315- rewriter.replaceOp (op, {result , trueVal});
353+ rewriter.replaceOp (op, {resultOrConversion , trueVal});
316354 return success ();
317355 }
318356};
0 commit comments