Skip to content

Commit bd19ae0

Browse files
committed
[mlir][IntRange] Poison support in int-range analysis
1 parent e9d71ef commit bd19ae0

File tree

9 files changed

+196
-32
lines changed

9 files changed

+196
-32
lines changed

mlir/include/mlir/Dialect/UB/IR/UBOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
1313
#include "mlir/IR/Dialect.h"
1414
#include "mlir/IR/OpImplementation.h"
15+
#include "mlir/Interfaces/InferIntRangeInterface.h"
1516
#include "mlir/Interfaces/SideEffectInterfaces.h"
1617

1718
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"

mlir/include/mlir/Dialect/UB/IR/UBOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
#ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
1010
#define MLIR_DIALECT_UB_IR_UBOPS_TD
1111

12-
include "mlir/Interfaces/SideEffectInterfaces.td"
1312
include "mlir/IR/AttrTypeBase.td"
13+
include "mlir/Interfaces/InferIntRangeInterface.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
1415

1516
include "UBOpsInterfaces.td"
1617

@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
3940
// PoisonOp
4041
//===----------------------------------------------------------------------===//
4142

42-
def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
43+
def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
44+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
4345
let summary = "Poisoned constant operation.";
4446
let description = [{
4547
The `poison` operation materializes a compile-time poisoned constant value

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ class ConstantIntRanges {
6262
/// sint_max(width)].
6363
static ConstantIntRanges maxRange(unsigned bitwidth);
6464

65+
/// Create a poisoned range, i.e. a range that represents no valid integer
66+
/// values.
67+
static ConstantIntRanges poison(unsigned bitwidth);
68+
6569
/// Create a `ConstantIntRanges` with a constant value - that is, with the
6670
/// bounds [value, value] for both its signed interpretations.
6771
static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +100,14 @@ class ConstantIntRanges {
96100
/// value.
97101
std::optional<APInt> getConstantValue() const;
98102

103+
/// Returns true if signed range is poisoned, i.e. no valid signed value
104+
/// can be represented.
105+
bool isSignedPoison() const;
106+
107+
/// Returns true if unsigned range is poisoned, i.e. no valid unsigned value
108+
/// can be represented.
109+
bool isUnsignedPoison() const;
110+
99111
friend raw_ostream &operator<<(raw_ostream &os,
100112
const ConstantIntRanges &range);
101113

mlir/lib/Dialect/UB/IR/UBOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
5959

6060
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
6161

62+
void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
63+
SetIntRangeFn setResultRange) {
64+
unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
65+
setResultRange(getResult(), ConstantIntRanges::poison(width));
66+
}
67+
6268
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
6369

6470
#define GET_ATTRDEF_CLASSES

mlir/lib/Interfaces/InferIntRangeInterface.cpp

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
4242
return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
4343
}
4444

45+
ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
46+
if (bitwidth == 0)
47+
return maxRange(0);
48+
49+
// Poison is represented by an empty range.
50+
auto max = APInt::getZero(bitwidth);
51+
auto min = max + 1;
52+
return {min, max, min, max};
53+
}
54+
4555
ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
4656
return {value, value, value, value};
4757
}
@@ -90,10 +100,32 @@ ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
90100
if (other.umin().getBitWidth() == 0)
91101
return other;
92102

93-
const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
94-
const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
95-
const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
96-
const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
103+
APInt uminUnion;
104+
APInt umaxUnion;
105+
APInt sminUnion;
106+
APInt smaxUnion;
107+
108+
if (isUnsignedPoison()) {
109+
uminUnion = other.umin();
110+
umaxUnion = other.umax();
111+
} else if (other.isUnsignedPoison()) {
112+
uminUnion = umin();
113+
umaxUnion = umax();
114+
} else {
115+
uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
116+
umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
117+
}
118+
119+
if (isSignedPoison()) {
120+
sminUnion = other.smin();
121+
smaxUnion = other.smax();
122+
} else if (other.isSignedPoison()) {
123+
sminUnion = smin();
124+
smaxUnion = smax();
125+
} else {
126+
sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
127+
smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
128+
}
97129

98130
return {uminUnion, umaxUnion, sminUnion, smaxUnion};
99131
}
@@ -107,10 +139,32 @@ ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
107139
if (other.umin().getBitWidth() == 0)
108140
return other;
109141

110-
const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
111-
const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
112-
const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
113-
const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
142+
APInt uminIntersect;
143+
APInt umaxIntersect;
144+
APInt sminIntersect;
145+
APInt smaxIntersect;
146+
147+
if (isUnsignedPoison()) {
148+
uminIntersect = umin();
149+
umaxIntersect = umax();
150+
} else if (other.isUnsignedPoison()) {
151+
uminIntersect = other.umin();
152+
umaxIntersect = other.umax();
153+
} else {
154+
uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
155+
umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
156+
}
157+
158+
if (isSignedPoison()) {
159+
sminIntersect = smin();
160+
smaxIntersect = smax();
161+
} else if (other.isSignedPoison()) {
162+
sminIntersect = other.smin();
163+
smaxIntersect = other.smax();
164+
} else {
165+
sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
166+
smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
167+
}
114168

115169
return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
116170
}
@@ -124,6 +178,10 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
124178
return std::nullopt;
125179
}
126180

181+
bool ConstantIntRanges::isSignedPoison() const { return smin().sgt(smax()); }
182+
183+
bool ConstantIntRanges::isUnsignedPoison() const { return umin().ugt(umax()); }
184+
127185
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
128186
os << "unsigned : [";
129187
range.umin().print(os, /*isSigned*/ false);

0 commit comments

Comments
 (0)