diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h index 7ffc861331760..3f6215458f90c 100644 --- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h +++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h @@ -65,11 +65,8 @@ class AttrConvertFastMathToLLVM { convertArithFastMathAttrToLLVM(arithFMFAttr)); } } - ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } - LLVM::IntegerOverflowFlags getOverflowFlags() const { - return LLVM::IntegerOverflowFlags::none; - } + Attribute getPropAttr() const { return {}; } private: NamedAttrList convertedAttr; @@ -82,23 +79,37 @@ template class AttrConvertOverflowToLLVM { public: AttrConvertOverflowToLLVM(SourceOp srcOp) { + using IntegerOverflowFlagsAttr = LLVM::IntegerOverflowFlagsAttr; + // Copy the source attributes. convertedAttr = NamedAttrList{srcOp->getAttrs()}; // Get the name of the arith overflow attribute. StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); - // Remove the source overflow attribute. + // Remove the source overflow attribute from the set that will be present + // in the target. if (auto arithAttr = dyn_cast_if_present( convertedAttr.erase(arithAttrName))) { - overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue()); + auto llvmFlag = convertArithOverflowFlagsToLLVM(arithAttr.getValue()); + // Create a dictionary attribute holding the overflow flags property. + // (In the LLVM dialect, the overflow flags are a property, not an + // attribute.) + MLIRContext *ctx = srcOp.getOperation()->getContext(); + Builder b(ctx); + auto llvmFlagAttr = IntegerOverflowFlagsAttr::get(ctx, llvmFlag); + StringRef llvmAttrName = TargetOp::getOverflowFlagsAttrName(); + SmallVector attrs; + attrs.push_back(b.getNamedAttr(llvmAttrName, llvmFlagAttr)); + // Set the properties attribute of the operation state so that the + // property can be updated when the operation is created. + propertiesAttr = b.getDictionaryAttr(attrs); } } - ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } - LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; } + Attribute getPropAttr() const { return propertiesAttr; } private: NamedAttrList convertedAttr; - LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none; + DictionaryAttr propertiesAttr; }; template @@ -129,9 +140,7 @@ class AttrConverterConstrainedFPToLLVM { } ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } - LLVM::IntegerOverflowFlags getOverflowFlags() const { - return LLVM::IntegerOverflowFlags::none; - } + Attribute getPropAttr() const { return {}; } private: NamedAttrList convertedAttr; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index c292e3727f46c..f8e0ccc093f8b 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -19,16 +19,14 @@ class CallOpInterface; namespace LLVM { namespace detail { -/// Handle generically setting flags as native properties on LLVM operations. -void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags); - /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. -LogicalResult oneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); +LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef targetAttrs, + Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); /// Replaces the given operation "op" with a call to an LLVM intrinsic with the /// specified name "intrinsic" and operands. @@ -307,9 +305,9 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), - adaptor.getOperands(), op->getAttrs(), - *this->getTypeConverter(), rewriter); + return LLVM::detail::oneToOneRewrite( + op, TargetOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), + /*propertiesAttr=*/Attribute{}, *this->getTypeConverter(), rewriter); } }; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index cad6cec761ab8..2cd3943ce02d9 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -54,25 +54,26 @@ LogicalResult handleMultidimensionalVectors( std::function createOperand, ConversionPatternRewriter &rewriter); -LogicalResult vectorOneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); +LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef targetAttrs, + Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); } // namespace detail } // namespace LLVM // Default attribute conversion class, which passes all source attributes -// through to the target op, unmodified. +// through to the target op, unmodified. The attribute to set properties of the +// target operation will be nullptr (i.e. any properties that exist in will have +// default values). template class AttrConvertPassThrough { public: AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {} ArrayRef getAttrs() const { return srcAttrs; } - LLVM::IntegerOverflowFlags getOverflowFlags() const { - return LLVM::IntegerOverflowFlags::none; - } + Attribute getPropAttr() const { return {}; } private: ArrayRef srcAttrs; @@ -80,10 +81,13 @@ class AttrConvertPassThrough { /// Basic lowering implementation to rewrite Ops with just one result to the /// LLVM Dialect. This supports higher-dimensional vector types. -/// The AttrConvert template template parameter should be a template class -/// with SourceOp and TargetOp type parameters, a constructor that takes -/// a SourceOp instance, and a getAttrs() method that returns -/// ArrayRef. +/// The AttrConvert template template parameter should: +// - be a template class with SourceOp and TargetOp type parameters +// - have a constructor that takes a SourceOp instance +// - a getAttrs() method that returns ArrayRef containing +// attributes that the target operation will have +// - a getPropAttr() method that returns either a NULL attribute or a +// DictionaryAttribute with properties that exist for the target operation template typename AttrConvert = AttrConvertPassThrough> @@ -134,8 +138,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { return LLVM::detail::vectorOneToOneRewrite( op, TargetOp::getOperationName(), adaptor.getOperands(), - attrConvert.getAttrs(), *this->getTypeConverter(), rewriter, - attrConvert.getOverflowFlags()); + attrConvert.getAttrs(), attrConvert.getPropAttr(), + *this->getTypeConverter(), rewriter); } }; } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index a38cf41a3e09b..14fcaa3968f10 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -158,6 +158,18 @@ class Arith_IntBinaryOpWithOverflowFlags traits = [ attr-dict `:` type($result) }]; } +class Arith_IntBinaryOpWithExactFlag traits = []> : + Arith_BinaryOp]>, + Arguments<(ins SignlessIntegerOrIndexLike:$lhs, + SignlessIntegerOrIndexLike:$rhs, + UnitAttr:$isExact)>, + Results<(outs SignlessIntegerOrIndexLike:$result)> { + + let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)? + attr-dict `:` type($result) }]; +} + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// @@ -482,7 +494,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative, // DivUIOp //===----------------------------------------------------------------------===// -def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> { +def Arith_DivUIOp : Arith_IntBinaryOpWithExactFlag<"divui", + [ConditionallySpeculatable]> { let summary = "unsigned integer division operation"; let description = [{ Unsigned integer division. Rounds towards zero. Treats the leading bit as @@ -493,12 +506,18 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> { `tensor` values, the behavior is undefined if _any_ elements are divided by zero. + If the `exact` attribute is present, the result value is poison if `lhs` is + not a multiple of `rhs`. + Example: ```mlir // Scalar unsigned integer division. %a = arith.divui %b, %c : i64 + // Scalar unsigned integer division where %b is known to be a multiple of %c. + %a = arith.divui %b, %c exact : i64 + // SIMD vector element-wise division. %f = arith.divui %g, %h : vector<4xi32> @@ -519,7 +538,8 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> { // DivSIOp //===----------------------------------------------------------------------===// -def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> { +def Arith_DivSIOp : Arith_IntBinaryOpWithExactFlag<"divsi", + [ConditionallySpeculatable]> { let summary = "signed integer division operation"; let description = [{ Signed integer division. Rounds towards zero. Treats the leading bit as @@ -530,12 +550,18 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> { behavior is undefined if _any_ of its elements are divided by zero or has a signed division overflow. + If the `exact` attribute is present, the result value is poison if `lhs` is + not a multiple of `rhs`. + Example: ```mlir // Scalar signed integer division. %a = arith.divsi %b, %c : i64 + // Scalar signed integer division where %b is known to be a multiple of %c. + %a = arith.divsi %b, %c exact : i64 + // SIMD vector element-wise division. %f = arith.divsi %g, %h : vector<4xi32> @@ -821,7 +847,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> { // ShRUIOp //===----------------------------------------------------------------------===// -def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> { +def Arith_ShRUIOp : Arith_IntBinaryOpWithExactFlag<"shrui", [Pure]> { let summary = "unsigned integer right-shift"; let description = [{ The `shrui` operation shifts an integer value of the first operand to the right @@ -830,12 +856,17 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> { filled with zeros. If the value of the second operand is greater or equal than the bitwidth of the first operand, then the operation returns poison. + If the `exact` keyword is present, the result value of shrui is a poison + value if any of the bits shifted out are non-zero. + Example: ```mlir - %1 = arith.constant 160 : i8 // %1 is 0b10100000 + %1 = arith.constant 160 : i8 // %1 is 0b10100000 %2 = arith.constant 3 : i8 - %3 = arith.shrui %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100 + %3 = arith.constant 6 : i8 + %4 = arith.shrui %1, %2 exact : i8 // %4 is 0b00010100 + %5 = arith.shrui %1, %3 : i8 // %3 is 0b00000010 ``` }]; let hasFolder = 1; @@ -845,7 +876,7 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> { // ShRSIOp //===----------------------------------------------------------------------===// -def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> { +def Arith_ShRSIOp : Arith_IntBinaryOpWithExactFlag<"shrsi", [Pure]> { let summary = "signed integer right-shift"; let description = [{ The `shrsi` operation shifts an integer value of the first operand to the right @@ -856,14 +887,17 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> { operand is greater or equal than bitwidth of the first operand, then the operation returns poison. + If the `exact` keyword is present, the result value of shrsi is a poison + value if any of the bits shifted out are non-zero. + Example: ```mlir - %1 = arith.constant 160 : i8 // %1 is 0b10100000 + %1 = arith.constant 160 : i8 // %1 is 0b10100000 %2 = arith.constant 3 : i8 - %3 = arith.shrsi %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100 - %4 = arith.constant 96 : i8 // %4 is 0b01100000 - %5 = arith.shrsi %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100 + %3 = arith.shrsi %1, %2 exact : i8 // %3 is 0b11110100 + %4 = arith.constant 98 : i8 // %4 is 0b01100010 + %5 = arith.shrsi %4, %2 : i8 // %5 is 0b00001100 ``` }]; let hasFolder = 1; diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 03ed4d51cc744..7f6eb1a5d764d 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -258,6 +258,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), + /*propAttr=*/Attribute{}, *getTypeConverter(), rewriter); } diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index 86d02e6c6209f..6a0c21185983e 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite( op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), - op->getAttrs(), *getTypeConverter(), rewriter); + op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(), + rewriter); } }; diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 48a03198fd465..f28a6ccb42455 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -296,19 +296,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Detail methods //===----------------------------------------------------------------------===// -void LLVM::detail::setNativeProperties(Operation *op, - IntegerOverflowFlags overflowFlags) { - if (auto iface = dyn_cast(op)) - iface.setOverflowFlags(overflowFlags); -} - /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. LogicalResult LLVM::detail::oneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { + ArrayRef targetAttrs, Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); SmallVector resultTypes; @@ -320,11 +314,10 @@ LogicalResult LLVM::detail::oneToOneRewrite( } // Create the operation through state since we don't know its C++ type. - Operation *newOp = - rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - resultTypes, targetAttrs); - - setNativeProperties(newOp, overflowFlags); + OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands, + resultTypes, targetAttrs); + state.propertiesAttr = propertiesAttr; + Operation *newOp = rewriter.create(state); // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index e7dd0b506e12d..24b01259f0499 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -105,9 +105,9 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { + ArrayRef targetAttrs, Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. @@ -116,15 +116,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( auto llvmNDVectorTy = operands[0].getType(); if (!isa(llvmNDVectorTy)) - return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, - rewriter, overflowFlags); - - auto callback = [op, targetOp, targetAttrs, overflowFlags, + return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr, + typeConverter, rewriter); + auto callback = [op, targetOp, targetAttrs, propertiesAttr, &rewriter](Type llvm1DVectorTy, ValueRange operands) { - Operation *newOp = - rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), - operands, llvm1DVectorTy, targetAttrs); - LLVM::detail::setNativeProperties(newOp, overflowFlags); + OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), + operands, llvm1DVectorTy, targetAttrs); + state.propertiesAttr = propertiesAttr; + Operation *newOp = rewriter.create(state); return newOp->getResult(0); }; diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index de3efc9fe3506..e256915933a71 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -389,8 +389,8 @@ def TruncIExtUIToExtUI : // trunci(shrsi(x, c)) -> trunci(shrui(x, c)) def TruncIShrSIToTrunciShrUI : Pat<(Arith_TruncIOp:$tr - (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow), - (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow), + (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0), $exact), $overflow), + (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)), $exact), $overflow), [(TruncationMatchesShiftAmount $x, $tr, $c0)]>; //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index b5dcb01d3dc6b..e19586abe442c 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -738,6 +738,22 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) { // ----- +// CHECK-LABEL: @ops_supporting_exact +func.func @ops_supporting_exact(i32, i32) { +^bb0(%arg0: i32, %arg1: i32): +// CHECK: = llvm.ashr exact %arg0, %arg1 : i32 + %0 = arith.shrsi %arg0, %arg1 exact : i32 +// CHECK: = llvm.lshr exact %arg0, %arg1 : i32 + %1 = arith.shrui %arg0, %arg1 exact : i32 +// CHECK: = llvm.sdiv exact %arg0, %arg1 : i32 + %2 = arith.divsi %arg0, %arg1 exact : i32 +// CHECK: = llvm.udiv exact %arg0, %arg1 : i32 + %3 = arith.divui %arg0, %arg1 exact : i32 + return +} + +// ----- + // CHECK-LABEL: func @memref_bitcast // CHECK-SAME: (%[[ARG:.*]]: memref) // CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 2fe0995c9d4df..3ad1530248809 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2958,6 +2958,19 @@ func.func @truncIShrSIToTrunciShrUI(%a: i64) -> i32 { return %hi : i32 } +// CHECK-LABEL: @truncIShrSIExactToTrunciShrUIExact +// CHECK-SAME: (%[[A:.+]]: i64) +// CHECK-NEXT: %[[C32:.+]] = arith.constant 32 : i64 +// CHECK-NEXT: %[[SHR:.+]] = arith.shrui %[[A]], %[[C32]] exact : i64 +// CHECK-NEXT: %[[TRU:.+]] = arith.trunci %[[SHR]] : i64 to i32 +// CHECK-NEXT: return %[[TRU]] : i32 +func.func @truncIShrSIExactToTrunciShrUIExact(%a: i64) -> i32 { + %c32 = arith.constant 32: i64 + %sh = arith.shrsi %a, %c32 exact : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + // CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt1 // CHECK: arith.shrsi func.func @truncIShrSIToTrunciShrUIBadShiftAmt1(%a: i64) -> i32 { diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index 1e656e84da836..58eadfda17060 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -151,6 +151,12 @@ func.func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_divui_exact +func.func @test_divui_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.divui %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_divui_tensor func.func @test_divui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.divui %arg0, %arg1 : tensor<8x8xi64> @@ -175,6 +181,12 @@ func.func @test_divsi(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_divsi_exact +func.func @test_divsi_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.divsi %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_divsi_tensor func.func @test_divsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.divsi %arg0, %arg1 : tensor<8x8xi64> @@ -391,6 +403,12 @@ func.func @test_shrui(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_shrui_exact +func.func @test_shrui_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.shrui %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_shrui_tensor func.func @test_shrui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.shrui %arg0, %arg1 : tensor<8x8xi64> @@ -415,6 +433,12 @@ func.func @test_shrsi(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_shrsi_exact +func.func @test_shrsi_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.shrsi %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_shrsi_tensor func.func @test_shrsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.shrsi %arg0, %arg1 : tensor<8x8xi64>