-
Couldn't load subscription status.
- Fork 15k
[mlir][spirv] Handle signedness casts #155388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Lower unrealized_conversion_cast of signed/unsigned/signless integer types of the same size to spirv.Bitcast. arith.bitcast is specifically for signless types, hence it is not used for such casts and unrealized_conversion_cast is used instead. Co-authored-by: Thomas Preud'homme <[email protected]>
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Michael Platings (mplatings) ChangesLower unrealized_conversion_cast of signed/unsigned/signless integer types of the same size to spirv.Bitcast. arith.bitcast is specifically for signless types, hence it is not used for such casts and unrealized_conversion_cast is used instead. Full diff: https://github.com/llvm/llvm-project/pull/155388.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265293b83f84c..ee694104dc918 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -856,6 +856,17 @@ convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
llvm_unreachable("Unhandled rounding mode");
}
+static bool isSignednessCast(Type srcType, Type dstType) {
+ if (srcType.isInteger() && dstType.isInteger()) {
+ return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
+ }
+ if (isa<VectorType>(srcType) && isa<VectorType>(dstType)) {
+ return isSignednessCast(cast<VectorType>(srcType).getElementType(),
+ cast<VectorType>(dstType).getElementType());
+ }
+ return false;
+}
+
/// Converts type-casting standard operations to SPIR-V operations.
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -864,42 +875,86 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
- Type dstType = this->getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
+ TypeRange dstTypes;
+ SmallVector<Type> newDstTypes;
+ SmallVector<Value> unrealizedConvCastSrcs;
+ SmallVector<Type> unrealizedConvCastDstTypes;
+ constexpr bool isUnrealizedConvCast =
+ std::is_same_v<Op, UnrealizedConversionCastOp>;
+ if constexpr (isUnrealizedConvCast)
+ dstTypes = op.getOutputs().getTypes();
+ else
+ dstTypes = op.getType();
+ LogicalResult matched = failure();
+ for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), dstTypes)) {
+ Type srcType = src.getType();
+ // Use UnrealizedConversionCast as the bridge so that we don't need to
+ // pull in patterns for other dialects.
+ if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) {
+ newDstTypes.push_back(dstType);
+ unrealizedConvCastSrcs.push_back(src);
+ unrealizedConvCastDstTypes.push_back(dstType);
+ continue;
+ }
+ dstType = this->getTypeConverter()->convertType(dstType);
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
+ return failure();
+ matched = success();
+ newDstTypes.push_back(dstType);
+ }
- if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
+ if (failed(matched))
return failure();
- if (dstType == srcType) {
- // Due to type conversion, we are seeing the same source and target type.
- // Then we can just erase this operation by forwarding its operand.
- rewriter.replaceOp(op, adaptor.getOperands().front());
- } else {
- // Compute new rounding mode (if any).
- std::optional<spirv::FPRoundingMode> rm = std::nullopt;
- if (auto roundingModeOp =
- dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
- if (arith::RoundingModeAttr roundingMode =
- roundingModeOp.getRoundingModeAttr()) {
- if (!(rm =
- convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
- return rewriter.notifyMatchFailure(
- op->getLoc(),
- llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
- }
+ // Compute new rounding mode (if any).
+ Location loc = op->getLoc();
+ std::optional<spirv::FPRoundingMode> rm = std::nullopt;
+ if (auto roundingModeOp =
+ dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
+ if (arith::RoundingModeAttr roundingMode =
+ roundingModeOp.getRoundingModeAttr()) {
+ if (!(rm = convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
+ return rewriter.notifyMatchFailure(
+ loc,
+ llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
}
}
- // Create replacement op and attach rounding mode attribute (if any).
- auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
- op, dstType, adaptor.getOperands());
- if (rm) {
- newOp->setAttr(
- getDecorationString(spirv::Decoration::FPRoundingMode),
- spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+ }
+
+ llvm::DenseMap<Value, Value> unrealizedConvCastSrcDstMap;
+ if (!unrealizedConvCastSrcs.empty()) {
+ auto newOp = rewriter.create<UnrealizedConversionCastOp>(
+ loc, unrealizedConvCastDstTypes, unrealizedConvCastSrcs);
+ for (auto [src, dst] :
+ llvm::zip(unrealizedConvCastSrcs, newOp.getResults()))
+ unrealizedConvCastSrcDstMap[src] = dst;
+ }
+
+ SmallVector<Value> newValues;
+ for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), newDstTypes)) {
+ Type srcType = src.getType();
+ if (dstType == srcType) {
+ // Due to type conversion, we are seeing the same source and target
+ // type. Then we can just erase this operation by forwarding its
+ // operand.
+ newValues.push_back(src);
+ } else if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) {
+ newValues.push_back(unrealizedConvCastSrcDstMap[src]);
+ } else {
+ // Create replacement op and attach rounding mode attribute (if any).
+ auto newOp = rewriter.template create<SPIRVOp>(loc, dstType, src);
+ if (rm) {
+ newOp->setAttr(
+ getDecorationString(spirv::Decoration::FPRoundingMode),
+ spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+ }
+ newValues.push_back(newOp.getResult());
}
}
+ rewriter.replaceOp(op, newValues);
return success();
}
};
@@ -1331,6 +1386,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
+ TypeCastingOpPattern<UnrealizedConversionCastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern,
AddUIExtendedOpPattern,
@@ -1385,8 +1441,17 @@ struct ConvertArithToSPIRVPass
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
- // in patterns for other dialects.
- target->addLegalOp<UnrealizedConversionCastOp>();
+ // in patterns for other dialects. If the UnrealizedConversionCast is
+ // between integers of the same bitwidth, it is either a nop or a
+ // signedness cast which the corresponding pattern convert to Bitcast.
+ target->addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+ [&](UnrealizedConversionCastOp op) {
+ for (auto [srcType, dstType] :
+ llvm::zip(op.getOperandTypes(), op.getResultTypes()))
+ if (isSignednessCast(srcType, dstType))
+ return false;
+ return true;
+ });
// Fail hard when there are any remaining 'arith' ops.
target->addIllegalDialect<arith::ArithDialect>();
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 6e2352e706acc..b9a4232758a17 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -743,6 +743,35 @@ func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
return
}
+// CHECK-LABEL: @unrealized_conversion_cast
+func.func @unrealized_conversion_cast(%arg0: vector<3xi64>, %arg1: i16, %arg2: f32) {
+ // CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
+ %0 = builtin.unrealized_conversion_cast %arg0 : vector<3xi64> to vector<3xui64>
+ // CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
+ %1 = builtin.unrealized_conversion_cast %arg1 : i16 to ui16
+
+ // CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
+ // CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
+ %2:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, ui16
+
+ // CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
+ %3:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xi64>, ui16
+ // CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
+ %4:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, i16
+
+ // bitcast from float to int should be represented using arith.bitcast
+ // CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to i32
+ %5 = builtin.unrealized_conversion_cast %arg2 : f32 to i32
+
+ // test mixed signedness and non-signedness cast
+ // CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to f16
+ // CHECK-NEXT: spirv.Bitcast %{{.+}} : i32 to ui32
+ %6:2 = builtin.unrealized_conversion_cast %5, %arg2 : i32, f32 to ui32, f16
+
+ // CHECK-NEXT: return
+ return
+}
+
// CHECK-LABEL: @fpext1
func.func @fpext1(%arg0: f16) -> f64 {
// CHECK: spirv.FConvert %{{.*}} : f16 to f64
|
| unrealizedConvCastDstTypes.push_back(dstType); | ||
| continue; | ||
| } | ||
| dstType = this->getTypeConverter()->convertType(dstType); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could it make sense to create a new variable newDstType to avoid confusion?
| } | ||
|
|
||
| if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) | ||
| if (failed(matched)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could we get rid of the matched variable, and instead create a const bool variable after the loop to check if !newDstTypes.empty() && (newDstTypes.size() > unrealizedConvCastDstTypes.size()) here, or am I missing something?
| spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm)); | ||
| } | ||
|
|
||
| llvm::DenseMap<Value, Value> unrealizedConvCastSrcDstMap; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could it make sense to add a comment explaining what's happening here? E.g. "Recreate unrealized_conversion_cast ops for unhandled casts" (if I understood the logic correctly)?
|
Thanks for the PR! Looks mostly good to me, I left just a few comments :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What produces these unrealized_conversion_casts? Could you show the input IR and the result of conversion to SPIR-V that runs into the signess issue?
It's not clear to be we should have dedicated handling for unrealized_conversion_casts.
One instance is #141096. The problem I'm encountering is around TOSA rescale because there's tooling that wants to use tensors of unsigned type if input_unsigned or output_unsigned are true. |
This fix looks fishy to me, I wouldn't not expect any unrealized_casts as an input for conversion to spirv. A motivating example at the level of arith/vector/sfc or the test dialect. would help me understand if this needs handling. |
I can create an arith example by running which gives you: Converting that with So there's a problem at (Apologies if I've misunderstood you, I'm relatively new to MLIR) |
I think the real issue is that step 2 (after tosa to linalg) has an unresolved cast. In general, these are supposed to be inserted automatically by the dialect conversion driver and eventually cancel out. Once you start emitting those manually, this assumption may no longer hold. |
I kind of agree, but it seems that unrealized_conversion_cast is the only way to represent signedness casts, given that arith.bitcast is explicitly only for signless casts. Any advice for how to proceed? |
|
Can you can add a new op to represent signedness casts? This seems useful in general to get signed/unsigned values in and out of arith. |
|
Potentially. Do you have any idea why arith.bitcast is specifically signless? The easy thing to do seems to be to relax that constraint |
I think it would be worth discussing on discourse. This has probably been considered before, but I'm not aware of any specific discussions. |
|
I found this old discussion: https://discourse.llvm.org/t/rfc-signednesscastop/3253. From that I conclude that signs should be removed from types before lowering to arith et al. And I agree that #141096 doesn't look right. |
Lower unrealized_conversion_cast of signed/unsigned/signless integer types of the same size to spirv.Bitcast.
arith.bitcast is specifically for signless types, hence it is not used for such casts and unrealized_conversion_cast is used instead.