Skip to content

Commit bc195a7

Browse files
committed
[mlir][IntRange] Poison support in int-range analysis
1 parent e9d71ef commit bc195a7

File tree

9 files changed

+216
-36
lines changed

9 files changed

+216
-36
lines changed

mlir/include/mlir/Dialect/UB/IR/UBOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
1313
#include "mlir/IR/Dialect.h"
1414
#include "mlir/IR/OpImplementation.h"
15+
#include "mlir/Interfaces/InferIntRangeInterface.h"
1516
#include "mlir/Interfaces/SideEffectInterfaces.h"
1617

1718
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"

mlir/include/mlir/Dialect/UB/IR/UBOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
#ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
1010
#define MLIR_DIALECT_UB_IR_UBOPS_TD
1111

12-
include "mlir/Interfaces/SideEffectInterfaces.td"
1312
include "mlir/IR/AttrTypeBase.td"
13+
include "mlir/Interfaces/InferIntRangeInterface.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
1415

1516
include "UBOpsInterfaces.td"
1617

@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
3940
// PoisonOp
4041
//===----------------------------------------------------------------------===//
4142

42-
def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
43+
def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
44+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
4345
let summary = "Poisoned constant operation.";
4446
let description = [{
4547
The `poison` operation materializes a compile-time poisoned constant value

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ class ConstantIntRanges {
5151
/// The maximum value of an integer when it is interpreted as signed.
5252
const APInt &smax() const;
5353

54+
/// Get the bitwidth of the ranges.
55+
unsigned getBitWidth() const;
56+
5457
/// Return the bitwidth that should be used for integer ranges describing
5558
/// `type`. For concrete integer types, this is their bitwidth, for `index`,
5659
/// this is the internal storage bitwidth of `index` attributes, and for
@@ -62,6 +65,10 @@ class ConstantIntRanges {
6265
/// sint_max(width)].
6366
static ConstantIntRanges maxRange(unsigned bitwidth);
6467

68+
/// Create a poisoned range, i.e. a range that represents no valid integer
69+
/// values.
70+
static ConstantIntRanges poison(unsigned bitwidth);
71+
6572
/// Create a `ConstantIntRanges` with a constant value - that is, with the
6673
/// bounds [value, value] for both its signed interpretations.
6774
static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +103,14 @@ class ConstantIntRanges {
96103
/// value.
97104
std::optional<APInt> getConstantValue() const;
98105

106+
/// Returns true if signed range is poisoned, i.e. no valid signed value
107+
/// can be represented.
108+
bool isSignedPoison() const;
109+
110+
/// Returns true if unsigned range is poisoned, i.e. no valid unsigned value
111+
/// can be represented.
112+
bool isUnsignedPoison() const;
113+
99114
friend raw_ostream &operator<<(raw_ostream &os,
100115
const ConstantIntRanges &range);
101116

mlir/lib/Dialect/UB/IR/UBOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
5959

6060
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
6161

62+
void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
63+
SetIntRangeFn setResultRange) {
64+
unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
65+
setResultRange(getResult(), ConstantIntRanges::poison(width));
66+
}
67+
6268
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
6369

6470
#define GET_ATTRDEF_CLASSES

mlir/lib/Interfaces/InferIntRangeInterface.cpp

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
2828

2929
const APInt &ConstantIntRanges::smax() const { return smaxVal; }
3030

31+
unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }
32+
3133
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
3234
type = getElementTypeOrSelf(type);
3335
if (type.isIndex())
@@ -42,6 +44,16 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
4244
return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
4345
}
4446

47+
ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
48+
if (bitwidth == 0)
49+
return maxRange(0);
50+
51+
// Poison is represented by an empty range.
52+
auto max = APInt::getZero(bitwidth);
53+
auto min = max + 1;
54+
return {min, max, min, max};
55+
}
56+
4557
ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
4658
return {value, value, value, value};
4759
}
@@ -85,15 +97,37 @@ ConstantIntRanges
8597
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
8698
// "Not an integer" poisons everything and also cannot be fed to comparison
8799
// operators.
88-
if (umin().getBitWidth() == 0)
100+
if (getBitWidth() == 0)
89101
return *this;
90-
if (other.umin().getBitWidth() == 0)
102+
if (other.getBitWidth() == 0)
91103
return other;
92104

93-
const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
94-
const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
95-
const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
96-
const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
105+
APInt uminUnion;
106+
APInt umaxUnion;
107+
APInt sminUnion;
108+
APInt smaxUnion;
109+
110+
if (isUnsignedPoison()) {
111+
uminUnion = other.umin();
112+
umaxUnion = other.umax();
113+
} else if (other.isUnsignedPoison()) {
114+
uminUnion = umin();
115+
umaxUnion = umax();
116+
} else {
117+
uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
118+
umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
119+
}
120+
121+
if (isSignedPoison()) {
122+
sminUnion = other.smin();
123+
smaxUnion = other.smax();
124+
} else if (other.isSignedPoison()) {
125+
sminUnion = smin();
126+
smaxUnion = smax();
127+
} else {
128+
sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
129+
smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
130+
}
97131

98132
return {uminUnion, umaxUnion, sminUnion, smaxUnion};
99133
}
@@ -102,15 +136,37 @@ ConstantIntRanges
102136
ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
103137
// "Not an integer" poisons everything and also cannot be fed to comparison
104138
// operators.
105-
if (umin().getBitWidth() == 0)
139+
if (getBitWidth() == 0)
106140
return *this;
107-
if (other.umin().getBitWidth() == 0)
141+
if (other.getBitWidth() == 0)
108142
return other;
109143

110-
const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
111-
const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
112-
const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
113-
const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
144+
APInt uminIntersect;
145+
APInt umaxIntersect;
146+
APInt sminIntersect;
147+
APInt smaxIntersect;
148+
149+
if (isUnsignedPoison()) {
150+
uminIntersect = umin();
151+
umaxIntersect = umax();
152+
} else if (other.isUnsignedPoison()) {
153+
uminIntersect = other.umin();
154+
umaxIntersect = other.umax();
155+
} else {
156+
uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
157+
umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
158+
}
159+
160+
if (isSignedPoison()) {
161+
sminIntersect = smin();
162+
smaxIntersect = smax();
163+
} else if (other.isSignedPoison()) {
164+
sminIntersect = other.smin();
165+
smaxIntersect = other.smax();
166+
} else {
167+
sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
168+
smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
169+
}
114170

115171
return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
116172
}
@@ -124,6 +180,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
124180
return std::nullopt;
125181
}
126182

183+
bool ConstantIntRanges::isSignedPoison() const {
184+
return getBitWidth() == 0 || smin().sgt(smax());
185+
}
186+
187+
bool ConstantIntRanges::isUnsignedPoison() const {
188+
return getBitWidth() == 0 || umin().ugt(umax());
189+
}
190+
127191
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
128192
os << "unsigned : [";
129193
range.umin().print(os, /*isSigned*/ false);

0 commit comments

Comments
 (0)