@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
262262 .Default ([](auto ) { return std::nullopt ; });
263263 }
264264
265- static std::optional<std::string> getFuncName (gpu::ShuffleOp op) {
266- StringRef baseName = getBaseName (op.getMode ());
267- std::optional<StringRef> typeMangling = getTypeMangling (op.getType (0 ));
265+ static std::optional<std::string> getFuncName (gpu::ShuffleMode mode,
266+ Type type) {
267+ StringRef baseName = getBaseName (mode);
268+ std::optional<StringRef> typeMangling = getTypeMangling (type);
268269 if (!typeMangling)
269270 return std::nullopt ;
270271 return llvm::formatv (" _Z{0}{1}{2}" , baseName.size (), baseName,
271272 typeMangling.value ());
272273 }
273274
275+ static std::optional<std::string> getFuncName (gpu::ShuffleOp op) {
276+ return getFuncName (op.getMode (), op.getType (0 ));
277+ }
278+
274279 // / Get the subgroup size from the target or return a default.
275280 static std::optional<int > getSubgroupSize (Operation *op) {
276281 auto parentFunc = op->getParentOfType <LLVM::LLVMFuncOp>();
@@ -286,30 +291,94 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
286291 val == getSubgroupSize (op);
287292 }
288293
294+ static bool needsBitCastOrExt (gpu::ShuffleOp op) {
295+ Type type = op.getType (0 );
296+ return isa<BFloat16Type>(type) || type.isInteger (1 );
297+ }
298+
299+ static Type getBitCastOrExtTy (Type oldTy,
300+ ConversionPatternRewriter &rewriter) {
301+ return TypeSwitch<Type, Type>(oldTy)
302+ .Case <BFloat16Type>([&](auto ) { return rewriter.getIntegerType (16 ); })
303+ .Case <IntegerType>([&](auto intTy) -> Type {
304+ if (intTy.getWidth () == 1 )
305+ return rewriter.getIntegerType (8 );
306+ return Type{};
307+ })
308+ .Default ([](auto ) { return Type{}; });
309+ }
310+
311+ static Value doBitcastOrExt (Value oldVal, Type newTy, Location loc,
312+ ConversionPatternRewriter &rewriter) {
313+ return TypeSwitch<Type, Value>(oldVal.getType ())
314+ .Case <BFloat16Type>([&](auto ) {
315+ return rewriter.create <LLVM::BitcastOp>(loc, newTy, oldVal);
316+ })
317+ .Case <IntegerType>([&](auto intTy) -> Value {
318+ if (intTy.getWidth () == 1 )
319+ return rewriter.create <LLVM::ZExtOp>(loc, newTy, oldVal);
320+ return Value{};
321+ })
322+ .Default ([](auto ) { return Value{}; });
323+ }
324+
325+ static Value doBitcastOrTrunc (Value oldVal, Type newTy, Location loc,
326+ ConversionPatternRewriter &rewriter) {
327+ return TypeSwitch<Type, Value>(newTy)
328+ .Case <BFloat16Type>([&](auto ) {
329+ return rewriter.create <LLVM::BitcastOp>(loc, newTy, oldVal);
330+ })
331+ .Case <IntegerType>([&](auto intTy) -> Value {
332+ if (intTy.getWidth () == 1 )
333+ return rewriter.create <LLVM::TruncOp>(loc, newTy, oldVal);
334+ return Value{};
335+ })
336+ .Default ([](auto ) { return Value{}; });
337+ }
338+
289339 LogicalResult
290340 matchAndRewrite (gpu::ShuffleOp op, OpAdaptor adaptor,
291341 ConversionPatternRewriter &rewriter) const final {
292342 if (!hasValidWidth (op))
293343 return rewriter.notifyMatchFailure (
294344 op, " shuffle width and subgroup size mismatch" );
295345
296- std::optional<std::string> funcName = getFuncName (op);
346+ Location loc = op->getLoc ();
347+ Type bitcastOrExtDestTy = getBitCastOrExtTy (op.getType (0 ), rewriter);
348+ std::optional<std::string> funcName;
349+ Value inValue;
350+ if (bitcastOrExtDestTy) {
351+ Value newVal =
352+ doBitcastOrExt (adaptor.getValue (), bitcastOrExtDestTy, loc, rewriter);
353+ assert (newVal && " Unhandled op type in bitcastorext" );
354+ funcName = getFuncName (op.getMode (), bitcastOrExtDestTy);
355+ inValue = newVal;
356+ } else {
357+ funcName = getFuncName (op);
358+ inValue = adaptor.getValue ();
359+ }
297360 if (!funcName)
298361 return rewriter.notifyMatchFailure (op, " unsupported value type" );
299362
300363 Operation *moduleOp = op->getParentWithTrait <OpTrait::SymbolTable>();
301364 assert (moduleOp && " Expecting module" );
302- Type valueType = adaptor. getValue () .getType ();
365+ Type valueType = inValue .getType ();
303366 Type offsetType = adaptor.getOffset ().getType ();
304367 Type resultType = valueType;
305368 LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn (
306369 moduleOp, funcName.value (), {valueType, offsetType}, resultType,
307370 /* isMemNone=*/ false , /* isConvergent=*/ true );
308371
309- Location loc = op->getLoc ();
310- std::array<Value, 2 > args{adaptor.getValue (), adaptor.getOffset ()};
372+ std::array<Value, 2 > args{inValue, adaptor.getOffset ()};
311373 Value result =
312374 createSPIRVBuiltinCall (loc, rewriter, func, args).getResult ();
375+ if (bitcastOrExtDestTy) {
376+ Value newVal =
377+ doBitcastOrTrunc (result, adaptor.getValue ().getType (), loc, rewriter);
378+ assert (newVal && " Unhandled op type in bitcastortrunc" );
379+ result = newVal;
380+ }
381+
313382 Value trueVal =
314383 rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI1Type (), true );
315384 rewriter.replaceOp (op, {result, trueVal});
0 commit comments