Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4569,9 +4569,38 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
return success();
}

/// Give mixed basis of affine.delinearize_index/linearize_index replace
/// constant SSA values with the constant integer value and returns the
/// new static basis.
static SmallVector<int64_t>
foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
MutableOperandRange mutableDynamicBasis,
ArrayRef<Attribute> dynamicBasis) {
SmallVector<int64_t> staticBasis;
for (OpFoldResult basis : mixedBasis) {
std::optional<int64_t> basisVal = getConstantIntValue(basis);
if (!basisVal)
staticBasis.push_back(ShapedType::kDynamic);
else
staticBasis.push_back(*basisVal);
}

int64_t dynamicBasisIndex = 0;
for (OpFoldResult basis : dynamicBasis) {
if (basis) {
mutableDynamicBasis.erase(dynamicBasisIndex);
} else {
++dynamicBasisIndex;
}
}
return staticBasis;
}

LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &result) {
setStaticBasis(foldCstValueToCstAttrBasis(
getMixedBasis(), getDynamicBasisMutable(), adaptor.getDynamicBasis()));
// If we won't be doing any division or modulo (no basis or the one basis
// element is purely advisory), simply return the input value.
if (getNumResults() == 1) {
Expand Down Expand Up @@ -4789,6 +4818,8 @@ LogicalResult AffineLinearizeIndexOp::verify() {
}

OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
setStaticBasis(foldCstValueToCstAttrBasis(
getMixedBasis(), getDynamicBasisMutable(), adaptor.getDynamicBasis()));
// No indices linearizes to zero.
if (getMultiIndex().empty())
return IntegerAttr::get(getResult().getType(), 0);
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1946,3 +1946,33 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
return %ret : index
}

// -----

// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[RET:.*]]:2 = affine.delinearize_index %[[ARG0]] into (3, 4) : index, index
// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[C0]] : index, index, index
func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
(index, index, index) {
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c1)
: index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}

// -----

// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
// CHECK: %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (3, 4) : index
// CHECK: return %[[RET]] : index
func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) ->
(index) {
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%1 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%c1, 3, %c4) : index
return %1 : index
}
Loading