diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index a68c0153df443..8b6b553f6eed0 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -262,12 +262,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { .Default([](auto) { return std::nullopt; }); } - static std::optional getFuncName(gpu::ShuffleOp op) { - StringRef baseName = getBaseName(op.getMode()); - std::optional typeMangling = getTypeMangling(op.getType(0)); + static std::optional getFuncName(gpu::ShuffleMode mode, + Type type) { + StringRef baseName = getBaseName(mode); + std::optional typeMangling = getTypeMangling(type); if (!typeMangling) return std::nullopt; - return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName, + return llvm::formatv("_Z{}{}{}", baseName.size(), baseName, typeMangling.value()); } @@ -286,6 +287,37 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { val == getSubgroupSize(op); } + static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc, + ConversionPatternRewriter &rewriter) { + return TypeSwitch(oldVal.getType()) + .Case([&](BFloat16Type) { + return rewriter.create(loc, rewriter.getI16Type(), + oldVal); + }) + .Case([&](IntegerType intTy) -> Value { + if (intTy.getWidth() == 1) + return rewriter.create(loc, rewriter.getI8Type(), + oldVal); + return oldVal; + }) + .Default(oldVal); + } + + static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy, + Location loc, + ConversionPatternRewriter &rewriter) { + return TypeSwitch(newTy) + .Case([&](BFloat16Type) { + return rewriter.create(loc, newTy, oldVal); + }) + .Case([&](IntegerType intTy) -> Value { + if (intTy.getWidth() == 1) + return rewriter.create(loc, newTy, oldVal); + return oldVal; + }) + .Default(oldVal); + } + LogicalResult matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { @@ -293,26 +325,32 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure( op, "shuffle width and subgroup size mismatch"); - std::optional funcName = getFuncName(op); + Location loc = op->getLoc(); + Value inValue = + bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter); + std::optional funcName = + getFuncName(op.getMode(), inValue.getType()); if (!funcName) return rewriter.notifyMatchFailure(op, "unsupported value type"); Operation *moduleOp = op->getParentWithTrait(); assert(moduleOp && "Expecting module"); - Type valueType = adaptor.getValue().getType(); + Type valueType = inValue.getType(); Type offsetType = adaptor.getOffset().getType(); Type resultType = valueType; LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn( moduleOp, funcName.value(), {valueType, offsetType}, resultType, /*isMemNone=*/false, /*isConvergent=*/true); - Location loc = op->getLoc(); - std::array args{adaptor.getValue(), adaptor.getOffset()}; + std::array args{inValue, adaptor.getOffset()}; Value result = createSPIRVBuiltinCall(loc, rewriter, func, args).getResult(); + Value resultOrConversion = + bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter); + Value trueVal = rewriter.create(loc, rewriter.getI1Type(), true); - rewriter.replaceOp(op, {result, trueVal}); + rewriter.replaceOp(op, {resultOrConversion, trueVal}); return success(); } }; diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir index e75225d6d54f5..c2930971dbcf9 100644 --- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir +++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir @@ -279,7 +279,8 @@ gpu.module @shuffles { // CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16, // CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64, // CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32, - // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32) + // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16, + // CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32) llvm.func @gpu_shuffles(%i8_val: i8, %i16_val: i16, %i32_val: i32, @@ -287,6 +288,8 @@ gpu.module @shuffles { %f16_val: f16, %f32_val: f32, %f64_val: f64, + %bf16_val: bf16, + %i1_val: i1, %offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} { %width = arith.constant 16 : i32 // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]]) @@ -303,6 +306,14 @@ gpu.module @shuffles { // CHECK: llvm.mlir.constant(true) : i1 // CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]]) // CHECK: llvm.mlir.constant(true) : i1 + // CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16 + // CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]]) + // CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16 + // CHECK: llvm.mlir.constant(true) : i1 + // CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8 + // CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%18, %arg9) + // CHECK: llvm.trunc %[[I1_CALL:.*]] : i8 to i1 + // CHECK: llvm.mlir.constant(true) : i1 %shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8 %shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16 %shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32 @@ -310,6 +321,8 @@ gpu.module @shuffles { %shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16 %shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32 %shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64 + %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16 + %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1 llvm.return } } @@ -344,10 +357,10 @@ gpu.module @shuffles_mismatch { // Cannot convert due to value type not being supported by the conversion gpu.module @not_supported_lowering { - llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} { + llvm.func @gpu_shuffles(%val: f128, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} { %width = arith.constant 32 : i32 // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}} - %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1 + %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : f128 llvm.return } }