Skip to content

Commit b985a17

Browse files
committed
[mlir][affine] Add ValueBoundsOpInterface to [de]linearize_index
Since a need for it came up dowstream (in proving that loops run at least once), this commit implements the ValueBoundsOpInterface for affine.delinearize_index and affine.linearize_index, using affine map representations of the operations they perform. For reasons that are unclear to me, attempting to provide the addition constraints that can be inferred from setting the outer bounds on a affine.delinearize_index doesn't work (that is, in the test, I know %0#0 is both %arg0 / 15 and < 2 (that is, that %arg0 < 30)) but I can't record this extra constraint. I've left this issue as-is for now.
1 parent b6960e2 commit b985a17

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,66 @@ struct AffineMaxOpInterface
9191
};
9292
};
9393

94+
struct AffineDelinearizeIndexOpInterface
95+
: public ValueBoundsOpInterface::ExternalModel<
96+
AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
97+
void populateBoundsForIndexValue(Operation *rawOp, Value value,
98+
ValueBoundsConstraintSet &cstr) const {
99+
auto op = cast<AffineDelinearizeIndexOp>(rawOp);
100+
auto result = cast<OpResult>(value);
101+
assert(result.getOwner() == rawOp &&
102+
"bounded value isn't a result of this delinearize_index");
103+
unsigned resIdx = result.getResultNumber();
104+
105+
AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex());
106+
107+
SmallVector<OpFoldResult> basis = op.getPaddedBasis();
108+
AffineExpr divisor = cstr.getExpr(1);
109+
for (OpFoldResult basisElem :
110+
ArrayRef<OpFoldResult>(basis).drop_front(resIdx + 1))
111+
divisor = divisor * cstr.getExpr(basisElem);
112+
113+
auto resBound = cstr.bound(result);
114+
if (resIdx == 0) {
115+
resBound == linearIdx.floorDiv(divisor);
116+
if (!basis.front().isNull())
117+
resBound < cstr.getExpr(basis.front());
118+
return;
119+
}
120+
AffineExpr thisBasis = cstr.getExpr(basis[resIdx]);
121+
resBound == (linearIdx % (thisBasis * divisor)).floorDiv(divisor);
122+
}
123+
};
124+
125+
struct AffineLinearizeIndexOpInterface
126+
: public ValueBoundsOpInterface::ExternalModel<
127+
AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
128+
void populateBoundsForIndexValue(Operation *rawOp, Value value,
129+
ValueBoundsConstraintSet &cstr) const {
130+
auto op = cast<AffineLinearizeIndexOp>(rawOp);
131+
assert(value == op.getResult() &&
132+
"value isn't the result of this linearize");
133+
134+
AffineExpr bound = cstr.getExpr(0);
135+
AffineExpr stride = cstr.getExpr(1);
136+
SmallVector<OpFoldResult> basis = op.getPaddedBasis();
137+
OperandRange multiIndex = op.getMultiIndex();
138+
for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) {
139+
unsigned argNum = multiIndex.size() - (revArgNum + 1);
140+
if (argNum == 0)
141+
break;
142+
OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]);
143+
bound = bound + cstr.getExpr(indexAsFoldRes) * stride;
144+
stride = stride * cstr.getExpr(length);
145+
}
146+
bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride;
147+
auto resBound = cstr.bound(value);
148+
resBound == bound;
149+
if (op.getDisjoint() && !basis.front().isNull()) {
150+
resBound <= stride *cstr.getExpr(basis.front());
151+
}
152+
}
153+
};
94154
} // namespace
95155
} // namespace mlir
96156

@@ -100,6 +160,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
100160
AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
101161
AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
102162
AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
163+
AffineDelinearizeIndexOp::attachInterface<
164+
AffineDelinearizeIndexOpInterface>(*ctx);
165+
AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>(
166+
*ctx);
103167
});
104168
}
105169

mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,41 @@ func.func @compare_maps(%a: index, %b: index) {
155155
: (index, index, index, index) -> ()
156156
return
157157
}
158+
159+
// -----
160+
161+
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)>
162+
// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)>
163+
// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)>
164+
// CHECK-LABEL: func.func @delinearize_static
165+
// CHECK-SAME: (%[[arg0:.+]]: index)
166+
// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]](}[%[[arg0]]]
167+
// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]](}[%[[arg0]]]
168+
// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]]
169+
// CHECK: return %[[v1]], %[[v2]], %[[v3]]
170+
func.func @delinearize_static(%arg0: index) -> (index, index, index) {
171+
%c2 = arith.constant 2 : index
172+
%c3 = arith.constant 3 : index
173+
%0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index
174+
%1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
175+
%2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
176+
%3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index)
177+
// TODO: why doesn't this return true? I'm setting the bound.
178+
"test.compaare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> ()
179+
// expected-remark @below{{true}}
180+
"test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> ()
181+
return %1, %2, %3 : index, index, index
182+
}
183+
184+
// -----
185+
186+
// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
187+
// CHECK-LABEL: func.func @linearize_static
188+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index)
189+
// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg0]], %[[arg1]]]
190+
// CHECK: return %[[v1]]
191+
func.func @linearize_static(%arg0: index, %arg1: index) -> index {
192+
%0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 3) : index
193+
%1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
194+
return %1 : index
195+
}

0 commit comments

Comments
 (0)