1515
1616#include " mlir/Dialect/Arith/IR/Arith.h"
1717#include " mlir/Dialect/EmitC/IR/EmitC.h"
18+ #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1819#include " mlir/IR/BuiltinAttributes.h"
1920#include " mlir/IR/BuiltinTypes.h"
2021#include " mlir/Transforms/DialectConversion.h"
@@ -35,8 +36,11 @@ class ArithConstantOpConversionPattern
3536 matchAndRewrite (arith::ConstantOp arithConst,
3637 arith::ConstantOp::Adaptor adaptor,
3738 ConversionPatternRewriter &rewriter) const override {
38- rewriter.replaceOpWithNewOp <emitc::ConstantOp>(
39- arithConst, arithConst.getType (), adaptor.getValue ());
39+ Type newTy = this ->getTypeConverter ()->convertType (arithConst.getType ());
40+ if (!newTy)
41+ return rewriter.notifyMatchFailure (arithConst, " type conversion failed" );
42+ rewriter.replaceOpWithNewOp <emitc::ConstantOp>(arithConst, newTy,
43+ adaptor.getValue ());
4044 return success ();
4145 }
4246};
@@ -51,6 +55,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
5155 return IntegerType::get (ty.getContext (), ty.getIntOrFloatBitWidth (),
5256 signedness);
5357 }
58+ } else if (emitc::isPointerWideType (ty)) {
59+ if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
60+ if (needsUnsigned)
61+ return emitc::SizeTType::get (ty.getContext ());
62+ return emitc::PtrDiffTType::get (ty.getContext ());
63+ }
5464 }
5565 return ty;
5666}
@@ -263,8 +273,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
263273 ConversionPatternRewriter &rewriter) const override {
264274
265275 Type type = adaptor.getLhs ().getType ();
266- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
267- return rewriter.notifyMatchFailure (op, " expected integer or index type" );
276+ if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType (type))) {
277+ return rewriter.notifyMatchFailure (
278+ op, " expected integer or size_t/ssize_t/ptrdiff_t type" );
268279 }
269280
270281 bool needsUnsigned = needsUnsignedCmp (op.getPredicate ());
@@ -317,17 +328,21 @@ class CastConversion : public OpConversionPattern<ArithOp> {
317328 ConversionPatternRewriter &rewriter) const override {
318329
319330 Type opReturnType = this ->getTypeConverter ()->convertType (op.getType ());
320- if (!isa_and_nonnull<IntegerType>(opReturnType))
321- return rewriter.notifyMatchFailure (op, " expected integer result type" );
331+ if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
332+ emitc::isPointerWideType (opReturnType)))
333+ return rewriter.notifyMatchFailure (
334+ op, " expected integer or size_t/ssize_t/ptrdiff_t result type" );
322335
323336 if (adaptor.getOperands ().size () != 1 ) {
324337 return rewriter.notifyMatchFailure (
325338 op, " CastConversion only supports unary ops" );
326339 }
327340
328341 Type operandType = adaptor.getIn ().getType ();
329- if (!isa_and_nonnull<IntegerType>(operandType))
330- return rewriter.notifyMatchFailure (op, " expected integer operand type" );
342+ if (!operandType || !(isa<IntegerType>(operandType) ||
343+ emitc::isPointerWideType (operandType)))
344+ return rewriter.notifyMatchFailure (
345+ op, " expected integer or size_t/ssize_t/ptrdiff_t operand type" );
331346
332347 // Signed (sign-extending) casts from i1 are not supported.
333348 if (operandType.isInteger (1 ) && !castToUnsigned)
@@ -338,8 +353,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
338353 // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
339354 // truncation.
340355 if (opReturnType.isInteger (1 )) {
356+ Type attrType = (emitc::isPointerWideType (operandType))
357+ ? rewriter.getIndexType ()
358+ : operandType;
341359 auto constOne = rewriter.create <emitc::ConstantOp>(
342- op.getLoc (), operandType, rewriter.getIntegerAttr (operandType, 1 ));
360+ op.getLoc (), operandType, rewriter.getOneAttr (attrType ));
343361 auto oneAndOperand = rewriter.create <emitc::BitwiseAndOp>(
344362 op.getLoc (), operandType, adaptor.getIn (), constOne);
345363 rewriter.replaceOpWithNewOp <emitc::CastOp>(op, opReturnType,
@@ -392,7 +410,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
392410 matchAndRewrite (ArithOp arithOp, typename ArithOp::Adaptor adaptor,
393411 ConversionPatternRewriter &rewriter) const override {
394412
395- rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, arithOp.getType (),
413+ Type newTy = this ->getTypeConverter ()->convertType (arithOp.getType ());
414+ if (!newTy)
415+ return rewriter.notifyMatchFailure (arithOp,
416+ " converting result type failed" );
417+ rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, newTy,
396418 adaptor.getOperands ());
397419
398420 return success ();
@@ -409,8 +431,9 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
409431 ConversionPatternRewriter &rewriter) const override {
410432
411433 Type type = this ->getTypeConverter ()->convertType (op.getType ());
412- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
413- return rewriter.notifyMatchFailure (op, " expected integer type" );
434+ if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType (type))) {
435+ return rewriter.notifyMatchFailure (
436+ op, " expected integer or size_t/ssize_t/ptrdiff_t type" );
414437 }
415438
416439 if (type.isInteger (1 )) {
@@ -481,6 +504,89 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
481504 }
482505};
483506
507+ template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
508+ class ShiftOpConversion : public OpConversionPattern <ArithOp> {
509+ public:
510+ using OpConversionPattern<ArithOp>::OpConversionPattern;
511+
512+ LogicalResult
513+ matchAndRewrite (ArithOp op, typename ArithOp::Adaptor adaptor,
514+ ConversionPatternRewriter &rewriter) const override {
515+
516+ Type type = this ->getTypeConverter ()->convertType (op.getType ());
517+ if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType (type))) {
518+ return rewriter.notifyMatchFailure (
519+ op, " expected integer or size_t/ssize_t/ptrdiff_t type" );
520+ }
521+
522+ if (type.isInteger (1 )) {
523+ return rewriter.notifyMatchFailure (op, " i1 type is not implemented" );
524+ }
525+
526+ Type arithmeticType = adaptIntegralTypeSignedness (type, isUnsignedOp);
527+
528+ Value lhs = adaptValueType (adaptor.getLhs (), rewriter, arithmeticType);
529+ // Shift amount interpreted as unsigned per Arith dialect spec.
530+ Type rhsType = adaptIntegralTypeSignedness (adaptor.getRhs ().getType (),
531+ /* needsUnsigned=*/ true );
532+ Value rhs = adaptValueType (adaptor.getRhs (), rewriter, rhsType);
533+
534+ // Add a runtime check for overflow
535+ Value width;
536+ if (emitc::isPointerWideType (type)) {
537+ Value eight = rewriter.create <emitc::ConstantOp>(
538+ op.getLoc (), rhsType, rewriter.getIndexAttr (8 ));
539+ emitc::CallOpaqueOp sizeOfCall = rewriter.create <emitc::CallOpaqueOp>(
540+ op.getLoc (), rhsType, " sizeof" , ArrayRef<Value>{eight});
541+ width = rewriter.create <emitc::MulOp>(op.getLoc (), rhsType, eight,
542+ sizeOfCall.getResult (0 ));
543+ } else {
544+ width = rewriter.create <emitc::ConstantOp>(
545+ op.getLoc (), rhsType,
546+ rewriter.getIntegerAttr (rhsType, type.getIntOrFloatBitWidth ()));
547+ }
548+
549+ Value excessCheck = rewriter.create <emitc::CmpOp>(
550+ op.getLoc (), rewriter.getI1Type (), emitc::CmpPredicate::lt, rhs, width);
551+
552+ // Any concrete value is a valid refinement of poison.
553+ Value poison = rewriter.create <emitc::ConstantOp>(
554+ op.getLoc (), arithmeticType,
555+ (isa<IntegerType>(arithmeticType)
556+ ? rewriter.getIntegerAttr (arithmeticType, 0 )
557+ : rewriter.getIndexAttr (0 )));
558+
559+ emitc::ExpressionOp ternary = rewriter.create <emitc::ExpressionOp>(
560+ op.getLoc (), arithmeticType, /* do_not_inline=*/ false );
561+ Block &bodyBlock = ternary.getBodyRegion ().emplaceBlock ();
562+ auto currentPoint = rewriter.getInsertionPoint ();
563+ rewriter.setInsertionPointToStart (&bodyBlock);
564+ Value arithmeticResult =
565+ rewriter.create <EmitCOp>(op.getLoc (), arithmeticType, lhs, rhs);
566+ Value resultOrPoison = rewriter.create <emitc::ConditionalOp>(
567+ op.getLoc (), arithmeticType, excessCheck, arithmeticResult, poison);
568+ rewriter.create <emitc::YieldOp>(op.getLoc (), resultOrPoison);
569+ rewriter.setInsertionPoint (op->getBlock (), currentPoint);
570+
571+ Value result = adaptValueType (ternary, rewriter, type);
572+
573+ rewriter.replaceOp (op, result);
574+ return success ();
575+ }
576+ };
577+
578+ template <typename ArithOp, typename EmitCOp>
579+ class SignedShiftOpConversion final
580+ : public ShiftOpConversion<ArithOp, EmitCOp, false > {
581+ using ShiftOpConversion<ArithOp, EmitCOp, false >::ShiftOpConversion;
582+ };
583+
584+ template <typename ArithOp, typename EmitCOp>
585+ class UnsignedShiftOpConversion final
586+ : public ShiftOpConversion<ArithOp, EmitCOp, true > {
587+ using ShiftOpConversion<ArithOp, EmitCOp, true >::ShiftOpConversion;
588+ };
589+
484590class SelectOpConversion : public OpConversionPattern <arith::SelectOp> {
485591public:
486592 using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -605,6 +711,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
605711 RewritePatternSet &patterns) {
606712 MLIRContext *ctx = patterns.getContext ();
607713
714+ mlir::populateEmitCSizeTTypeConversions (typeConverter);
715+
608716 // clang-format off
609717 patterns.add <
610718 ArithConstantOpConversionPattern,
@@ -620,6 +728,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
620728 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
621729 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
622730 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
731+ UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
732+ SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
733+ UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
623734 CmpFOpConversion,
624735 CmpIOpConversion,
625736 NegFOpConversion,
@@ -628,6 +739,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
628739 UnsignedCastConversion<arith::TruncIOp>,
629740 SignedCastConversion<arith::ExtSIOp>,
630741 UnsignedCastConversion<arith::ExtUIOp>,
742+ SignedCastConversion<arith::IndexCastOp>,
743+ UnsignedCastConversion<arith::IndexCastUIOp>,
631744 ItoFCastOpConversion<arith::SIToFPOp>,
632745 ItoFCastOpConversion<arith::UIToFPOp>,
633746 FtoICastOpConversion<arith::FPToSIOp>,
0 commit comments