@@ -856,6 +856,17 @@ convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
856856 llvm_unreachable (" Unhandled rounding mode" );
857857}
858858
859+ static bool isSignednessCast (Type srcType, Type dstType) {
860+ if (srcType.isInteger () && dstType.isInteger ()) {
861+ return srcType.getIntOrFloatBitWidth () == dstType.getIntOrFloatBitWidth ();
862+ }
863+ if (isa<VectorType>(srcType) && isa<VectorType>(dstType)) {
864+ return isSignednessCast (cast<VectorType>(srcType).getElementType (),
865+ cast<VectorType>(dstType).getElementType ());
866+ }
867+ return false ;
868+ }
869+
859870// / Converts type-casting standard operations to SPIR-V operations.
860871template <typename Op, typename SPIRVOp>
861872struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -864,42 +875,86 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
864875 LogicalResult
865876 matchAndRewrite (Op op, typename Op::Adaptor adaptor,
866877 ConversionPatternRewriter &rewriter) const override {
867- Type srcType = llvm::getSingleElement (adaptor.getOperands ()).getType ();
868- Type dstType = this ->getTypeConverter ()->convertType (op.getType ());
869- if (!dstType)
870- return getTypeConversionFailure (rewriter, op);
878+ TypeRange dstTypes;
879+ SmallVector<Type> newDstTypes;
880+ SmallVector<Value> unrealizedConvCastSrcs;
881+ SmallVector<Type> unrealizedConvCastDstTypes;
882+ constexpr bool isUnrealizedConvCast =
883+ std::is_same_v<Op, UnrealizedConversionCastOp>;
884+ if constexpr (isUnrealizedConvCast)
885+ dstTypes = op.getOutputs ().getTypes ();
886+ else
887+ dstTypes = op.getType ();
888+ LogicalResult matched = failure ();
889+ for (auto [src, dstType] : llvm::zip (adaptor.getOperands (), dstTypes)) {
890+ Type srcType = src.getType ();
891+ // Use UnrealizedConversionCast as the bridge so that we don't need to
892+ // pull in patterns for other dialects.
893+ if (isUnrealizedConvCast && !isSignednessCast (srcType, dstType)) {
894+ newDstTypes.push_back (dstType);
895+ unrealizedConvCastSrcs.push_back (src);
896+ unrealizedConvCastDstTypes.push_back (dstType);
897+ continue ;
898+ }
899+ dstType = this ->getTypeConverter ()->convertType (dstType);
900+ if (!dstType)
901+ return getTypeConversionFailure (rewriter, op);
902+
903+ if (isBoolScalarOrVector (srcType) || isBoolScalarOrVector (dstType))
904+ return failure ();
905+ matched = success ();
906+ newDstTypes.push_back (dstType);
907+ }
871908
872- if (isBoolScalarOrVector (srcType) || isBoolScalarOrVector (dstType ))
909+ if (failed (matched ))
873910 return failure ();
874911
875- if (dstType == srcType) {
876- // Due to type conversion, we are seeing the same source and target type.
877- // Then we can just erase this operation by forwarding its operand.
878- rewriter.replaceOp (op, adaptor.getOperands ().front ());
879- } else {
880- // Compute new rounding mode (if any).
881- std::optional<spirv::FPRoundingMode> rm = std::nullopt ;
882- if (auto roundingModeOp =
883- dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
884- if (arith::RoundingModeAttr roundingMode =
885- roundingModeOp.getRoundingModeAttr ()) {
886- if (!(rm =
887- convertArithRoundingModeToSPIRV (roundingMode.getValue ()))) {
888- return rewriter.notifyMatchFailure (
889- op->getLoc (),
890- llvm::formatv (" unsupported rounding mode '{0}'" , roundingMode));
891- }
912+ // Compute new rounding mode (if any).
913+ Location loc = op->getLoc ();
914+ std::optional<spirv::FPRoundingMode> rm = std::nullopt ;
915+ if (auto roundingModeOp =
916+ dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
917+ if (arith::RoundingModeAttr roundingMode =
918+ roundingModeOp.getRoundingModeAttr ()) {
919+ if (!(rm = convertArithRoundingModeToSPIRV (roundingMode.getValue ()))) {
920+ return rewriter.notifyMatchFailure (
921+ loc,
922+ llvm::formatv (" unsupported rounding mode '{0}'" , roundingMode));
892923 }
893924 }
894- // Create replacement op and attach rounding mode attribute (if any).
895- auto newOp = rewriter.template replaceOpWithNewOp <SPIRVOp>(
896- op, dstType, adaptor.getOperands ());
897- if (rm) {
898- newOp->setAttr (
899- getDecorationString (spirv::Decoration::FPRoundingMode),
900- spirv::FPRoundingModeAttr::get (rewriter.getContext (), *rm));
925+ }
926+
927+ llvm::DenseMap<Value, Value> unrealizedConvCastSrcDstMap;
928+ if (!unrealizedConvCastSrcs.empty ()) {
929+ auto newOp = rewriter.create <UnrealizedConversionCastOp>(
930+ loc, unrealizedConvCastDstTypes, unrealizedConvCastSrcs);
931+ for (auto [src, dst] :
932+ llvm::zip (unrealizedConvCastSrcs, newOp.getResults ()))
933+ unrealizedConvCastSrcDstMap[src] = dst;
934+ }
935+
936+ SmallVector<Value> newValues;
937+ for (auto [src, dstType] : llvm::zip (adaptor.getOperands (), newDstTypes)) {
938+ Type srcType = src.getType ();
939+ if (dstType == srcType) {
940+ // Due to type conversion, we are seeing the same source and target
941+ // type. Then we can just erase this operation by forwarding its
942+ // operand.
943+ newValues.push_back (src);
944+ } else if (isUnrealizedConvCast && !isSignednessCast (srcType, dstType)) {
945+ newValues.push_back (unrealizedConvCastSrcDstMap[src]);
946+ } else {
947+ // Create replacement op and attach rounding mode attribute (if any).
948+ auto newOp = rewriter.template create <SPIRVOp>(loc, dstType, src);
949+ if (rm) {
950+ newOp->setAttr (
951+ getDecorationString (spirv::Decoration::FPRoundingMode),
952+ spirv::FPRoundingModeAttr::get (rewriter.getContext (), *rm));
953+ }
954+ newValues.push_back (newOp.getResult ());
901955 }
902956 }
957+ rewriter.replaceOp (op, newValues);
903958 return success ();
904959 }
905960};
@@ -1331,6 +1386,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
13311386 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
13321387 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
13331388 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1389+ TypeCastingOpPattern<UnrealizedConversionCastOp, spirv::BitcastOp>,
13341390 CmpIOpBooleanPattern, CmpIOpPattern,
13351391 CmpFOpNanNonePattern, CmpFOpPattern,
13361392 AddUIExtendedOpPattern,
@@ -1385,8 +1441,17 @@ struct ConvertArithToSPIRVPass
13851441 SPIRVTypeConverter typeConverter (targetAttr, options);
13861442
13871443 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1388- // in patterns for other dialects.
1389- target->addLegalOp <UnrealizedConversionCastOp>();
1444+ // in patterns for other dialects. If the UnrealizedConversionCast is
1445+ // between integers of the same bitwidth, it is either a nop or a
1446+ // signedness cast which the corresponding pattern convert to Bitcast.
1447+ target->addDynamicallyLegalOp <UnrealizedConversionCastOp>(
1448+ [&](UnrealizedConversionCastOp op) {
1449+ for (auto [srcType, dstType] :
1450+ llvm::zip (op.getOperandTypes (), op.getResultTypes ()))
1451+ if (isSignednessCast (srcType, dstType))
1452+ return false ;
1453+ return true ;
1454+ });
13901455
13911456 // Fail hard when there are any remaining 'arith' ops.
13921457 target->addIllegalDialect <arith::ArithDialect>();
0 commit comments