diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index adc27ae6bdafb..993f36f556e87 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -226,7 +226,7 @@ def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> { these is required to be the same type. This type may be an integer scalar type, a vector whose element type is integer, or a tensor of integers. - This op supports `nuw`/`nsw` overflow flags which stands stand for + This op supports `nuw`/`nsw` overflow flags which stands for "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or `nsw` flags are present, and an unsigned/signed overflow occurs (respectively), the result is poison. @@ -321,7 +321,7 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> { these is required to be the same type. This type may be an integer scalar type, a vector whose element type is integer, or a tensor of integers. - This op supports `nuw`/`nsw` overflow flags which stands stand for + This op supports `nuw`/`nsw` overflow flags which stands for "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or `nsw` flags are present, and an unsigned/signed overflow occurs (respectively), the result is poison. @@ -367,7 +367,7 @@ def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", these is required to be the same type. This type may be an integer scalar type, a vector whose element type is integer, or a tensor of integers. - This op supports `nuw`/`nsw` overflow flags which stands stand for + This op supports `nuw`/`nsw` overflow flags which stands for "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or `nsw` flags are present, and an unsigned/signed overflow occurs (respectively), the result is poison. @@ -800,7 +800,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> { operand is greater or equal than the bitwidth of the first operand, then the operation returns poison. - This op supports `nuw`/`nsw` overflow flags which stands stand for + This op supports `nuw`/`nsw` overflow flags which stands for "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or `nsw` flags are present, and an unsigned/signed overflow occurs (respectively), the result is poison. @@ -1271,7 +1271,11 @@ def Arith_ScalingExtFOp // TruncIOp //===----------------------------------------------------------------------===// -def Arith_TruncIOp : Arith_IToICastOp<"trunci"> { +def Arith_TruncIOp : Op, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "integer truncation operation"; let description = [{ The integer truncation operation takes an integer input of @@ -1279,17 +1283,37 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> { bit-width must be smaller than the input bit-width (N < M). The top-most (N - M) bits of the input are discarded. + This op supports `nuw`/`nsw` overflow flags which stands for "No Unsigned + Wrap" and "No Signed Wrap", respectively. If the nuw keyword is present, + and any of the truncated bits are non-zero, the result is a poison value. + If the nsw keyword is present, and any of the truncated bits are not the + same as the top bit of the truncation result, the result is a poison value. + Example: ```mlir + // Scalar truncation. %1 = arith.constant 21 : i5 // %1 is 0b10101 %2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101 %3 = arith.trunci %1 : i5 to i3 // %3 is 0b101 - %5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16> + // Vector truncation. + %4 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16> + + // Scalar truncation with overflow flags. + %5 = arith.trunci %a overflow : i32 to i16 ``` }]; + let arguments = (ins + SignlessFixedWidthIntegerLike:$in, + DefaultValuedAttr:$overflowFlags); + let results = (outs SignlessFixedWidthIntegerLike:$out); + let assemblyFormat = [{ + $in (`overflow` `` $overflowFlags^)? attr-dict + `:` type($in) `to` type($out) + }]; let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index ced18a48766bf..b8e5aa87244fa 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -163,7 +163,8 @@ using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, arith::AttrConverterConstrainedFPToLLVM>; using TruncIOpLowering = - VectorConvertToLLVMPattern; + VectorConvertToLLVMPattern; using UIToFPOpLowering = VectorConvertToLLVMPattern; using XOrIOpLowering = VectorConvertToLLVMPattern; diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index 13eb97a910bd4..b61612436eb78 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -378,14 +378,14 @@ def TruncationMatchesShiftAmount : // trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated def TruncIExtSIToExtSI : - Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)), + Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x), $overflow), (Arith_ExtSIOp $x), [(ValueWiderThan $ext, $tr), (ValueWiderThan $tr, $x)]>; // trunci(extui(x)) -> extui(x), when only the zero-extension bits are truncated def TruncIExtUIToExtUI : - Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x)), + Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x), $overflow), (Arith_ExtUIOp $x), [(ValueWiderThan $ext, $tr), (ValueWiderThan $tr, $x)]>; @@ -393,8 +393,8 @@ def TruncIExtUIToExtUI : // trunci(shrsi(x, c)) -> trunci(shrui(x, c)) def TruncIShrSIToTrunciShrUI : Pat<(Arith_TruncIOp:$tr - (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0))), - (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))), + (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow), + (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow), [(TruncationMatchesShiftAmount $x, $tr, $c0)]>; // trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y) @@ -402,7 +402,7 @@ def TruncIShrUIMulIToMulSIExtended : Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp (Arith_MulIOp:$mul (Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1), - (ConstantLikeMatcher AnyAttr:$c0))), + (ConstantLikeMatcher AnyAttr:$c0)), $overflow), (Arith_MulSIExtendedOp:$res__1 $x, $y), [(ValuesWithSameType $tr, $x, $y), (ValueWiderThan $mul, $x), @@ -413,7 +413,7 @@ def TruncIShrUIMulIToMulUIExtended : Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp (Arith_MulIOp:$mul (Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1), - (ConstantLikeMatcher AnyAttr:$c0))), + (ConstantLikeMatcher AnyAttr:$c0)), $overflow), (Arith_MulUIExtendedOp:$res__1 $x, $y), [(ValuesWithSameType $tr, $x, $y), (ValueWiderThan $mul, $x), diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index e0d974ea74041..83bdbe1f67118 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -731,6 +731,8 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) { %2 = arith.muli %arg0, %arg1 overflow : i64 // CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} overflow : i64 %3 = arith.shli %arg0, %arg1 overflow : i64 + // CHECK: %{{.*}} = llvm.trunc %{{.*}} overflow : i64 to i32 + %4 = arith.trunci %arg0 overflow : i64 to i32 return } diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index f684e02344a51..1e656e84da836 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -1159,5 +1159,7 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) { %2 = arith.muli %arg0, %arg1 overflow : i64 // CHECK: %{{.*}} = arith.shli %{{.*}}, %{{.*}} overflow : i64 %3 = arith.shli %arg0, %arg1 overflow : i64 + // CHECK: %{{.*}} = arith.trunci %{{.*}} overflow : i64 to i32 + %4 = arith.trunci %arg0 overflow : i64 to i32 return }