-
Couldn't load subscription status.
- Fork 15k
[mlir][spirv] Handle signedness casts #155388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -856,6 +856,17 @@ convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) { | |
| llvm_unreachable("Unhandled rounding mode"); | ||
| } | ||
|
|
||
| static bool isSignednessCast(Type srcType, Type dstType) { | ||
| if (srcType.isInteger() && dstType.isInteger()) { | ||
| return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); | ||
| } | ||
| if (isa<VectorType>(srcType) && isa<VectorType>(dstType)) { | ||
| return isSignednessCast(cast<VectorType>(srcType).getElementType(), | ||
| cast<VectorType>(dstType).getElementType()); | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| /// Converts type-casting standard operations to SPIR-V operations. | ||
| template <typename Op, typename SPIRVOp> | ||
| struct TypeCastingOpPattern final : public OpConversionPattern<Op> { | ||
|
|
@@ -864,42 +875,86 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> { | |
| LogicalResult | ||
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType(); | ||
| Type dstType = this->getTypeConverter()->convertType(op.getType()); | ||
| if (!dstType) | ||
| return getTypeConversionFailure(rewriter, op); | ||
| TypeRange dstTypes; | ||
| SmallVector<Type> newDstTypes; | ||
| SmallVector<Value> unrealizedConvCastSrcs; | ||
| SmallVector<Type> unrealizedConvCastDstTypes; | ||
| constexpr bool isUnrealizedConvCast = | ||
| std::is_same_v<Op, UnrealizedConversionCastOp>; | ||
| if constexpr (isUnrealizedConvCast) | ||
| dstTypes = op.getOutputs().getTypes(); | ||
| else | ||
| dstTypes = op.getType(); | ||
| LogicalResult matched = failure(); | ||
| for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), dstTypes)) { | ||
| Type srcType = src.getType(); | ||
| // Use UnrealizedConversionCast as the bridge so that we don't need to | ||
| // pull in patterns for other dialects. | ||
| if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) { | ||
| newDstTypes.push_back(dstType); | ||
| unrealizedConvCastSrcs.push_back(src); | ||
| unrealizedConvCastDstTypes.push_back(dstType); | ||
| continue; | ||
| } | ||
| dstType = this->getTypeConverter()->convertType(dstType); | ||
| if (!dstType) | ||
| return getTypeConversionFailure(rewriter, op); | ||
|
|
||
| if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) | ||
| return failure(); | ||
| matched = success(); | ||
| newDstTypes.push_back(dstType); | ||
| } | ||
|
|
||
| if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) | ||
| if (failed(matched)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we get rid of the |
||
| return failure(); | ||
|
|
||
| if (dstType == srcType) { | ||
| // Due to type conversion, we are seeing the same source and target type. | ||
| // Then we can just erase this operation by forwarding its operand. | ||
| rewriter.replaceOp(op, adaptor.getOperands().front()); | ||
| } else { | ||
| // Compute new rounding mode (if any). | ||
| std::optional<spirv::FPRoundingMode> rm = std::nullopt; | ||
| if (auto roundingModeOp = | ||
| dyn_cast<arith::ArithRoundingModeInterface>(*op)) { | ||
| if (arith::RoundingModeAttr roundingMode = | ||
| roundingModeOp.getRoundingModeAttr()) { | ||
| if (!(rm = | ||
| convertArithRoundingModeToSPIRV(roundingMode.getValue()))) { | ||
| return rewriter.notifyMatchFailure( | ||
| op->getLoc(), | ||
| llvm::formatv("unsupported rounding mode '{0}'", roundingMode)); | ||
| } | ||
| // Compute new rounding mode (if any). | ||
| Location loc = op->getLoc(); | ||
| std::optional<spirv::FPRoundingMode> rm = std::nullopt; | ||
| if (auto roundingModeOp = | ||
| dyn_cast<arith::ArithRoundingModeInterface>(*op)) { | ||
| if (arith::RoundingModeAttr roundingMode = | ||
| roundingModeOp.getRoundingModeAttr()) { | ||
| if (!(rm = convertArithRoundingModeToSPIRV(roundingMode.getValue()))) { | ||
| return rewriter.notifyMatchFailure( | ||
| loc, | ||
| llvm::formatv("unsupported rounding mode '{0}'", roundingMode)); | ||
| } | ||
| } | ||
| // Create replacement op and attach rounding mode attribute (if any). | ||
| auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>( | ||
| op, dstType, adaptor.getOperands()); | ||
| if (rm) { | ||
| newOp->setAttr( | ||
| getDecorationString(spirv::Decoration::FPRoundingMode), | ||
| spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm)); | ||
| } | ||
|
|
||
| llvm::DenseMap<Value, Value> unrealizedConvCastSrcDstMap; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: could it make sense to add a comment explaining what's happening here? E.g. "Recreate unrealized_conversion_cast ops for unhandled casts" (if I understood the logic correctly)? |
||
| if (!unrealizedConvCastSrcs.empty()) { | ||
| auto newOp = rewriter.create<UnrealizedConversionCastOp>( | ||
| loc, unrealizedConvCastDstTypes, unrealizedConvCastSrcs); | ||
| for (auto [src, dst] : | ||
| llvm::zip(unrealizedConvCastSrcs, newOp.getResults())) | ||
| unrealizedConvCastSrcDstMap[src] = dst; | ||
| } | ||
|
|
||
| SmallVector<Value> newValues; | ||
| for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), newDstTypes)) { | ||
| Type srcType = src.getType(); | ||
| if (dstType == srcType) { | ||
| // Due to type conversion, we are seeing the same source and target | ||
| // type. Then we can just erase this operation by forwarding its | ||
| // operand. | ||
| newValues.push_back(src); | ||
| } else if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) { | ||
| newValues.push_back(unrealizedConvCastSrcDstMap[src]); | ||
| } else { | ||
| // Create replacement op and attach rounding mode attribute (if any). | ||
| auto newOp = rewriter.template create<SPIRVOp>(loc, dstType, src); | ||
| if (rm) { | ||
| newOp->setAttr( | ||
| getDecorationString(spirv::Decoration::FPRoundingMode), | ||
| spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm)); | ||
| } | ||
| newValues.push_back(newOp.getResult()); | ||
| } | ||
| } | ||
| rewriter.replaceOp(op, newValues); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
@@ -1331,6 +1386,7 @@ void mlir::arith::populateArithToSPIRVPatterns( | |
| TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, | ||
| TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>, | ||
| TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, | ||
| TypeCastingOpPattern<UnrealizedConversionCastOp, spirv::BitcastOp>, | ||
| CmpIOpBooleanPattern, CmpIOpPattern, | ||
| CmpFOpNanNonePattern, CmpFOpPattern, | ||
| AddUIExtendedOpPattern, | ||
|
|
@@ -1385,8 +1441,17 @@ struct ConvertArithToSPIRVPass | |
| SPIRVTypeConverter typeConverter(targetAttr, options); | ||
|
|
||
| // Use UnrealizedConversionCast as the bridge so that we don't need to pull | ||
| // in patterns for other dialects. | ||
| target->addLegalOp<UnrealizedConversionCastOp>(); | ||
| // in patterns for other dialects. If the UnrealizedConversionCast is | ||
| // between integers of the same bitwidth, it is either a nop or a | ||
| // signedness cast which the corresponding pattern convert to Bitcast. | ||
| target->addDynamicallyLegalOp<UnrealizedConversionCastOp>( | ||
| [&](UnrealizedConversionCastOp op) { | ||
| for (auto [srcType, dstType] : | ||
| llvm::zip(op.getOperandTypes(), op.getResultTypes())) | ||
| if (isSignednessCast(srcType, dstType)) | ||
| return false; | ||
| return true; | ||
| }); | ||
|
|
||
| // Fail hard when there are any remaining 'arith' ops. | ||
| target->addIllegalDialect<arith::ArithDialect>(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could it make sense to create a new variable
newDstTypeto avoid confusion?