Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,29 +556,25 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
/*isSigned=*/false);
}

/// Get bitmask of all bits which can change while iterating in
/// [bound.umin(), bound.umax()].
static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
APInt leftVal = bound.umin(), rightVal = bound.umax();
unsigned bitwidth = leftVal.getBitWidth();
unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
return APInt::getLowBitsSet(bitwidth, differingBits);
}

ConstantIntRanges
mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
// TODO: The code below doesn't work for bitwidths > i1.
// For input ranges lhs=[2060639849, 2060639850], rhs=[2060639849, 2060639849]
// widenBitwiseBounds will produce:
// lhs:
// 2060639848 01111010110100101101111001101000
// 2060639851 01111010110100101101111001101011
// rhs:
// 2060639849 01111010110100101101111001101001
// 2060639849 01111010110100101101111001101001
// None of those combinations xor to 0, while intermediate values does.
unsigned width = argRanges[0].umin().getBitWidth();
if (width > 1)
return ConstantIntRanges::maxRange(width);

auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
return a ^ b;
};
return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
/*isSigned=*/false);
// Construct mask of varying bits for both ranges, xor values and then replace
// masked bits with 0s and 1s to get min and max values respectively.
ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
APInt res = lhs.umin() ^ rhs.umin();
APInt min = res & ~mask;
APInt max = res | mask;
return ConstantIntRanges::fromUnsigned(min, max);
}

//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Arith/int-range-interface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,8 @@ func.func @xori_i1() -> (i1, i1) {
}

// CHECK-LABEL: func @xori
// TODO: xor folding is temporarily disabled
// CHECK-NOT: arith.constant false
// CHECK: %[[false:.*]] = arith.constant false
// CHECK: return %[[false]]
func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 {
%c0 = arith.constant 0 : i64
%c7 = arith.constant 7 : i64
Expand Down
Loading