diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp index 2c1276d577a55..7a73a94201f1d 100644 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -556,29 +556,25 @@ mlir::intrange::inferOr(ArrayRef argRanges) { /*isSigned=*/false); } +/// Get bitmask of all bits which can change while iterating in +/// [bound.umin(), bound.umax()]. +static APInt getVaryingBitsMask(const ConstantIntRanges &bound) { + APInt leftVal = bound.umin(), rightVal = bound.umax(); + unsigned bitwidth = leftVal.getBitWidth(); + unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero(); + return APInt::getLowBitsSet(bitwidth, differingBits); +} + ConstantIntRanges mlir::intrange::inferXor(ArrayRef argRanges) { - // TODO: The code below doesn't work for bitwidths > i1. - // For input ranges lhs=[2060639849, 2060639850], rhs=[2060639849, 2060639849] - // widenBitwiseBounds will produce: - // lhs: - // 2060639848 01111010110100101101111001101000 - // 2060639851 01111010110100101101111001101011 - // rhs: - // 2060639849 01111010110100101101111001101001 - // 2060639849 01111010110100101101111001101001 - // None of those combinations xor to 0, while intermediate values does. - unsigned width = argRanges[0].umin().getBitWidth(); - if (width > 1) - return ConstantIntRanges::maxRange(width); - - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto xori = [](const APInt &a, const APInt &b) -> std::optional { - return a ^ b; - }; - return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false); + // Construct mask of varying bits for both ranges, xor values and then replace + // masked bits with 0s and 1s to get min and max values respectively. + ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1]; + APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs); + APInt res = lhs.umin() ^ rhs.umin(); + APInt min = res & ~mask; + APInt max = res | mask; + return ConstantIntRanges::fromUnsigned(min, max); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir index 4db846fa4656a..48a3eb20eb7fb 100644 --- a/mlir/test/Dialect/Arith/int-range-interface.mlir +++ b/mlir/test/Dialect/Arith/int-range-interface.mlir @@ -481,8 +481,8 @@ func.func @xori_i1() -> (i1, i1) { } // CHECK-LABEL: func @xori -// TODO: xor folding is temporarily disabled -// CHECK-NOT: arith.constant false +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 { %c0 = arith.constant 0 : i64 %c7 = arith.constant 7 : i64