diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 67d7da622a355..e889ae638986b 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4569,9 +4569,49 @@ LogicalResult AffineDelinearizeIndexOp::verify() { return success(); } +/// Given mixed basis of affine.delinearize_index/linearize_index replace +/// constant SSA values with the constant integer value and return the new +/// static basis. In case no such candidate for replacement exists, this utility +/// returns std::nullopt. +static std::optional> +foldCstValueToCstAttrBasis(ArrayRef mixedBasis, + MutableOperandRange mutableDynamicBasis, + ArrayRef dynamicBasis) { + int64_t dynamicBasisIndex = 0; + for (OpFoldResult basis : dynamicBasis) { + if (basis) { + mutableDynamicBasis.erase(dynamicBasisIndex); + } else { + ++dynamicBasisIndex; + } + } + + // No constant SSA value exists. + if (dynamicBasisIndex == dynamicBasis.size()) + return std::nullopt; + + SmallVector staticBasis; + for (OpFoldResult basis : mixedBasis) { + std::optional basisVal = getConstantIntValue(basis); + if (!basisVal) + staticBasis.push_back(ShapedType::kDynamic); + else + staticBasis.push_back(*basisVal); + } + + return staticBasis; +} + LogicalResult AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor, SmallVectorImpl &result) { + std::optional> maybeStaticBasis = + foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(), + adaptor.getDynamicBasis()); + if (maybeStaticBasis) { + setStaticBasis(*maybeStaticBasis); + return success(); + } // If we won't be doing any division or modulo (no basis or the one basis // element is purely advisory), simply return the input value. if (getNumResults() == 1) { @@ -4789,6 +4829,13 @@ LogicalResult AffineLinearizeIndexOp::verify() { } OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) { + std::optional> maybeStaticBasis = + foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(), + adaptor.getDynamicBasis()); + if (maybeStaticBasis) { + setStaticBasis(*maybeStaticBasis); + return getResult(); + } // No indices linearizes to zero. if (getMultiIndex().empty()) return IntegerAttr::get(getResult().getType(), 0); diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index 5384977151b47..b747178c5b1a9 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -1946,3 +1946,32 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind return %ret : index } +// ----- + +// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index +// CHECK-SAME: (%[[ARG0:.*]]: index) +// CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index +// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2 : index, index, index +func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) -> + (index, index, index) { + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c2) + : index, index, index + return %0#0, %0#1, %0#2 : index, index, index +} + +// ----- + +// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +// CHECK: %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 3, 4) : index +// CHECK: return %[[RET]] : index +func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) -> + (index) { + %c4 = arith.constant 4 : index + %c2 = arith.constant 2 : index + %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%c2, 3, %c4) : index + return %0 : index +}