Skip to content

Commit f89476d

Browse files
committed
[mlir][IntRange] Poison support in int-range analysis
1 parent e9d71ef commit f89476d

File tree

12 files changed

+277
-39
lines changed

12 files changed

+277
-39
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
4949
// Explicitly depend on "arith" because this pass could create operations in
5050
// `arith` out of thin air in some cases.
5151
let dependentDialects = [
52-
"::mlir::arith::ArithDialect"
52+
"::mlir::arith::ArithDialect",
53+
"::mlir::ub::UBDialect"
5354
];
5455
}
5556

mlir/include/mlir/Dialect/UB/IR/UBOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
1313
#include "mlir/IR/Dialect.h"
1414
#include "mlir/IR/OpImplementation.h"
15+
#include "mlir/Interfaces/InferIntRangeInterface.h"
1516
#include "mlir/Interfaces/SideEffectInterfaces.h"
1617

1718
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"

mlir/include/mlir/Dialect/UB/IR/UBOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
#ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
1010
#define MLIR_DIALECT_UB_IR_UBOPS_TD
1111

12-
include "mlir/Interfaces/SideEffectInterfaces.td"
1312
include "mlir/IR/AttrTypeBase.td"
13+
include "mlir/Interfaces/InferIntRangeInterface.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
1415

1516
include "UBOpsInterfaces.td"
1617

@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
3940
// PoisonOp
4041
//===----------------------------------------------------------------------===//
4142

42-
def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
43+
def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
44+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
4345
let summary = "Poisoned constant operation.";
4446
let description = [{
4547
The `poison` operation materializes a compile-time poisoned constant value

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ class ConstantIntRanges {
5151
/// The maximum value of an integer when it is interpreted as signed.
5252
const APInt &smax() const;
5353

54+
/// Get the bitwidth of the ranges.
55+
unsigned getBitWidth() const;
56+
5457
/// Return the bitwidth that should be used for integer ranges describing
5558
/// `type`. For concrete integer types, this is their bitwidth, for `index`,
5659
/// this is the internal storage bitwidth of `index` attributes, and for
@@ -62,6 +65,10 @@ class ConstantIntRanges {
6265
/// sint_max(width)].
6366
static ConstantIntRanges maxRange(unsigned bitwidth);
6467

68+
/// Create a poisoned range, i.e. a range that represents no valid integer
69+
/// values.
70+
static ConstantIntRanges poison(unsigned bitwidth);
71+
6572
/// Create a `ConstantIntRanges` with a constant value - that is, with the
6673
/// bounds [value, value] for both its signed interpretations.
6774
static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +103,14 @@ class ConstantIntRanges {
96103
/// value.
97104
std::optional<APInt> getConstantValue() const;
98105

106+
/// Returns true if signed range is poisoned, i.e. no valid signed value
107+
/// can be represented.
108+
bool isSignedPoison() const;
109+
110+
/// Returns true if unsigned range is poisoned, i.e. no valid unsigned value
111+
/// can be represented.
112+
bool isUnsignedPoison() const;
113+
99114
friend raw_ostream &operator<<(raw_ostream &os,
100115
const ConstantIntRanges &range);
101116

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
1515
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
1616
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/UB/IR/UBOps.h"
1718
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1819
#include "mlir/IR/IRMapping.h"
1920
#include "mlir/IR/Matchers.h"
@@ -46,6 +47,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
4647
return inferredRange.getConstantValue();
4748
}
4849

50+
static bool isPoison(DataFlowSolver &solver, Value value) {
51+
auto *maybeInferredRange =
52+
solver.lookupState<IntegerValueRangeLattice>(value);
53+
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
54+
return false;
55+
const ConstantIntRanges &inferredRange =
56+
maybeInferredRange->getValue().getValue();
57+
return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
58+
}
59+
4960
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
5061
Value newVal) {
5162
assert(oldVal.getType() == newVal.getType() &&
@@ -63,6 +74,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
6374
RewriterBase &rewriter, Value value) {
6475
if (value.use_empty())
6576
return failure();
77+
78+
if (isPoison(solver, value)) {
79+
Value poison =
80+
ub::PoisonOp::create(rewriter, value.getLoc(), value.getType());
81+
if (solver.lookupState<dataflow::IntegerValueRangeLattice>(poison))
82+
solver.eraseState(poison);
83+
copyIntegerRange(solver, value, poison);
84+
rewriter.replaceAllUsesWith(value, poison);
85+
return success();
86+
}
87+
6688
std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
6789
if (!maybeConstValue.has_value())
6890
return failure();
@@ -131,7 +153,8 @@ struct MaterializeKnownConstantValues : public RewritePattern {
131153
return failure();
132154

133155
auto needsReplacing = [&](Value v) {
134-
return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
156+
return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) &&
157+
!v.use_empty();
135158
};
136159
bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
137160
if (op->getNumRegions() == 0)

mlir/lib/Dialect/UB/IR/UBOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
5959

6060
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
6161

62+
void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
63+
SetIntRangeFn setResultRange) {
64+
unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
65+
setResultRange(getResult(), ConstantIntRanges::poison(width));
66+
}
67+
6268
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
6369

6470
#define GET_ATTRDEF_CLASSES

mlir/lib/Interfaces/InferIntRangeInterface.cpp

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
2828

2929
const APInt &ConstantIntRanges::smax() const { return smaxVal; }
3030

31+
unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }
32+
3133
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
3234
type = getElementTypeOrSelf(type);
3335
if (type.isIndex())
@@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
4244
return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
4345
}
4446

47+
ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
48+
if (bitwidth == 0) {
49+
auto zero = APInt::getZero(0);
50+
return {zero, zero, zero, zero};
51+
}
52+
53+
// Poison is represented by an empty range.
54+
auto zero = APInt::getZero(bitwidth);
55+
auto one = zero + 1;
56+
auto onem = zero - 1;
57+
// For i1 the valid unsigned range is [0, 1] and the valid signed range
58+
// is [-1, 0].
59+
return {one, zero, zero, onem};
60+
}
61+
4562
ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
4663
return {value, value, value, value};
4764
}
@@ -85,15 +102,37 @@ ConstantIntRanges
85102
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
86103
// "Not an integer" poisons everything and also cannot be fed to comparison
87104
// operators.
88-
if (umin().getBitWidth() == 0)
105+
if (getBitWidth() == 0)
89106
return *this;
90-
if (other.umin().getBitWidth() == 0)
107+
if (other.getBitWidth() == 0)
91108
return other;
92109

93-
const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
94-
const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
95-
const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
96-
const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
110+
APInt uminUnion;
111+
APInt umaxUnion;
112+
APInt sminUnion;
113+
APInt smaxUnion;
114+
115+
if (isUnsignedPoison()) {
116+
uminUnion = other.umin();
117+
umaxUnion = other.umax();
118+
} else if (other.isUnsignedPoison()) {
119+
uminUnion = umin();
120+
umaxUnion = umax();
121+
} else {
122+
uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
123+
umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
124+
}
125+
126+
if (isSignedPoison()) {
127+
sminUnion = other.smin();
128+
smaxUnion = other.smax();
129+
} else if (other.isSignedPoison()) {
130+
sminUnion = smin();
131+
smaxUnion = smax();
132+
} else {
133+
sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
134+
smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
135+
}
97136

98137
return {uminUnion, umaxUnion, sminUnion, smaxUnion};
99138
}
@@ -102,15 +141,37 @@ ConstantIntRanges
102141
ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
103142
// "Not an integer" poisons everything and also cannot be fed to comparison
104143
// operators.
105-
if (umin().getBitWidth() == 0)
144+
if (getBitWidth() == 0)
106145
return *this;
107-
if (other.umin().getBitWidth() == 0)
146+
if (other.getBitWidth() == 0)
108147
return other;
109148

110-
const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
111-
const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
112-
const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
113-
const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
149+
APInt uminIntersect;
150+
APInt umaxIntersect;
151+
APInt sminIntersect;
152+
APInt smaxIntersect;
153+
154+
if (isUnsignedPoison()) {
155+
uminIntersect = umin();
156+
umaxIntersect = umax();
157+
} else if (other.isUnsignedPoison()) {
158+
uminIntersect = other.umin();
159+
umaxIntersect = other.umax();
160+
} else {
161+
uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
162+
umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
163+
}
164+
165+
if (isSignedPoison()) {
166+
sminIntersect = smin();
167+
smaxIntersect = smax();
168+
} else if (other.isSignedPoison()) {
169+
sminIntersect = other.smin();
170+
smaxIntersect = other.smax();
171+
} else {
172+
sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
173+
smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
174+
}
114175

115176
return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
116177
}
@@ -124,6 +185,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
124185
return std::nullopt;
125186
}
126187

188+
bool ConstantIntRanges::isSignedPoison() const {
189+
return getBitWidth() > 0 && smin().sgt(smax());
190+
}
191+
192+
bool ConstantIntRanges::isUnsignedPoison() const {
193+
return getBitWidth() > 0 && umin().ugt(umax());
194+
}
195+
127196
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
128197
os << "unsigned : [";
129198
range.umin().print(os, /*isSigned*/ false);

0 commit comments

Comments
 (0)