Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4569,9 +4569,49 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
return success();
}

/// Given mixed basis of affine.delinearize_index/linearize_index replace
/// constant SSA values with the constant integer value and return the new
/// static basis. In case no such candidate for replacement exists, this utility
/// returns std::nullopt.
static std::optional<SmallVector<int64_t>>
foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
MutableOperandRange mutableDynamicBasis,
ArrayRef<Attribute> dynamicBasis) {
int64_t dynamicBasisIndex = 0;
for (OpFoldResult basis : dynamicBasis) {
if (basis) {
mutableDynamicBasis.erase(dynamicBasisIndex);
} else {
++dynamicBasisIndex;
}
}

// No constant SSA value exists.
if (dynamicBasisIndex == dynamicBasis.size())
return std::nullopt;

SmallVector<int64_t> staticBasis;
for (OpFoldResult basis : mixedBasis) {
std::optional<int64_t> basisVal = getConstantIntValue(basis);
if (!basisVal)
staticBasis.push_back(ShapedType::kDynamic);
else
staticBasis.push_back(*basisVal);
}

return staticBasis;
}

LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &result) {
std::optional<SmallVector<int64_t>> maybeStaticBasis =
foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
adaptor.getDynamicBasis());
if (maybeStaticBasis) {
setStaticBasis(*maybeStaticBasis);
return success();
}
// If we won't be doing any division or modulo (no basis or the one basis
// element is purely advisory), simply return the input value.
if (getNumResults() == 1) {
Expand Down Expand Up @@ -4789,6 +4829,13 @@ LogicalResult AffineLinearizeIndexOp::verify() {
}

OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
std::optional<SmallVector<int64_t>> maybeStaticBasis =
foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
adaptor.getDynamicBasis());
if (maybeStaticBasis) {
setStaticBasis(*maybeStaticBasis);
return getResult();
}
// No indices linearizes to zero.
if (getMultiIndex().empty())
return IntegerAttr::get(getResult().getType(), 0);
Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1946,3 +1946,32 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
return %ret : index
}

// -----

// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index)
// CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index
// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2 : index, index, index
func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
(index, index, index) {
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c2)
: index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}

// -----

// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
// CHECK: %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 3, 4) : index
// CHECK: return %[[RET]] : index
func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) ->
(index) {
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%c2, 3, %c4) : index
return %0 : index
}