Skip to content

Commit d83148f

Browse files
[MLIR][Affine] Update ::fold() to have constant basis attr for affine.delinearize_index/linearize_index (#117572)
-- This commit updates `::fold()` to have constant(CST) attribute for affine.delinearize_index/linearize_index op's basis wherever applicable. -- Essentially the code checks if the mixed basis OpFoldResult set contains any constant SSA value and converts it to a constant integer instead. Signed-off-by: Abhishek Varma <[email protected]>
1 parent b2d3cb1 commit d83148f

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4569,9 +4569,49 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
45694569
return success();
45704570
}
45714571

4572+
/// Given mixed basis of affine.delinearize_index/linearize_index replace
4573+
/// constant SSA values with the constant integer value and return the new
4574+
/// static basis. In case no such candidate for replacement exists, this utility
4575+
/// returns std::nullopt.
4576+
static std::optional<SmallVector<int64_t>>
4577+
foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
4578+
MutableOperandRange mutableDynamicBasis,
4579+
ArrayRef<Attribute> dynamicBasis) {
4580+
int64_t dynamicBasisIndex = 0;
4581+
for (OpFoldResult basis : dynamicBasis) {
4582+
if (basis) {
4583+
mutableDynamicBasis.erase(dynamicBasisIndex);
4584+
} else {
4585+
++dynamicBasisIndex;
4586+
}
4587+
}
4588+
4589+
// No constant SSA value exists.
4590+
if (dynamicBasisIndex == dynamicBasis.size())
4591+
return std::nullopt;
4592+
4593+
SmallVector<int64_t> staticBasis;
4594+
for (OpFoldResult basis : mixedBasis) {
4595+
std::optional<int64_t> basisVal = getConstantIntValue(basis);
4596+
if (!basisVal)
4597+
staticBasis.push_back(ShapedType::kDynamic);
4598+
else
4599+
staticBasis.push_back(*basisVal);
4600+
}
4601+
4602+
return staticBasis;
4603+
}
4604+
45724605
LogicalResult
45734606
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
45744607
SmallVectorImpl<OpFoldResult> &result) {
4608+
std::optional<SmallVector<int64_t>> maybeStaticBasis =
4609+
foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4610+
adaptor.getDynamicBasis());
4611+
if (maybeStaticBasis) {
4612+
setStaticBasis(*maybeStaticBasis);
4613+
return success();
4614+
}
45754615
// If we won't be doing any division or modulo (no basis or the one basis
45764616
// element is purely advisory), simply return the input value.
45774617
if (getNumResults() == 1) {
@@ -4875,6 +4915,13 @@ LogicalResult AffineLinearizeIndexOp::verify() {
48754915
}
48764916

48774917
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
4918+
std::optional<SmallVector<int64_t>> maybeStaticBasis =
4919+
foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4920+
adaptor.getDynamicBasis());
4921+
if (maybeStaticBasis) {
4922+
setStaticBasis(*maybeStaticBasis);
4923+
return getResult();
4924+
}
48784925
// No indices linearizes to zero.
48794926
if (getMultiIndex().empty())
48804927
return IntegerAttr::get(getResult().getType(), 0);

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2012,3 +2012,32 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
20122012
return %ret : index
20132013
}
20142014

2015+
// -----
2016+
2017+
// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
2018+
// CHECK-SAME: (%[[ARG0:.*]]: index)
2019+
// CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index
2020+
// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2 : index, index, index
2021+
func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
2022+
(index, index, index) {
2023+
%c4 = arith.constant 4 : index
2024+
%c3 = arith.constant 3 : index
2025+
%c2 = arith.constant 2 : index
2026+
%0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c2)
2027+
: index, index, index
2028+
return %0#0, %0#1, %0#2 : index, index, index
2029+
}
2030+
2031+
// -----
2032+
2033+
// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
2034+
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
2035+
// CHECK: %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 3, 4) : index
2036+
// CHECK: return %[[RET]] : index
2037+
func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) ->
2038+
(index) {
2039+
%c4 = arith.constant 4 : index
2040+
%c2 = arith.constant 2 : index
2041+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%c2, 3, %c4) : index
2042+
return %0 : index
2043+
}

0 commit comments

Comments
 (0)