From 33cfdd9762e5eb2faffc7f4c93012c2c569af94d Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Wed, 27 Nov 2024 14:31:32 -0600 Subject: [PATCH] [mlir] Add ValueBoundsOpInterfaceImpl for affine.delinearize_index and affine.linearize_index Signed-off-by: Max Dawkins --- .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp | 65 +++++++++++++++++++ .../value-bounds-op-interface-impl.mlir | 37 +++++++++++ 2 files changed, 102 insertions(+) diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index 82a9fb0d49088..77107fb894bb0 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -49,6 +49,67 @@ struct AffineApplyOpInterface } }; +struct AffineDelinearizeIndexOpInterface + : public ValueBoundsOpInterface::ExternalModel< + AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto delinearizeOp = cast(op); + auto result = cast(value); + int64_t resultIdx = result.getResultNumber(); + assert(result.getOwner() == delinearizeOp && "invalid value"); + + AffineExpr linearIdxExpr = cstr.getExpr(delinearizeOp.getLinearIndex()); + SmallVector basis = delinearizeOp.getMixedBasis(); + SmallVector basisExprs; + AffineExpr modExpr = getAffineConstantExpr(1, op->getContext()); + AffineExpr strideExpr = getAffineConstantExpr(1, op->getContext()); + for (int i = basis.size() - 1; i >= resultIdx; --i) { + AffineExpr basisExpr = cstr.getExpr(basis[i]); + modExpr = modExpr * basisExpr; + if (i > resultIdx) + strideExpr = strideExpr * basisExpr; + } + AffineExpr bound = linearIdxExpr; + if (resultIdx > 0) + bound = bound % modExpr; + if (resultIdx < delinearizeOp->getNumResults()) + bound = bound.floorDiv(strideExpr); + + cstr.bound(value) == bound; + } +}; + +struct AffineLinearizeIndexOpInterface + : public ValueBoundsOpInterface::ExternalModel< + AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto linearizeOp = cast(op); + assert(value == linearizeOp.getResult() && "invalid value"); + + SmallVector basis = linearizeOp.getMixedBasis(); + SmallVector basisExprs = llvm::map_to_vector( + basis, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); }); + basisExprs.push_back(getAffineConstantExpr(1, op->getContext())); + + SmallVector indices(linearizeOp.getMultiIndex()); + SmallVector indexExprs = llvm::map_to_vector( + indices, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); }); + + AffineExpr linearIdxExpr = getAffineConstantExpr(0, op->getContext()); + AffineExpr strideExpr = getAffineConstantExpr(1, op->getContext()); + std::reverse(indexExprs.begin(), indexExprs.end()); + std::reverse(basisExprs.begin(), basisExprs.end()); + for (size_t i = 0; i < indexExprs.size(); ++i) { + strideExpr = strideExpr * basisExprs[i]; + linearIdxExpr = linearIdxExpr + indexExprs[i] * strideExpr; + } + + cstr.bound(value) == linearIdxExpr; + } +}; + struct AffineMinOpInterface : public ValueBoundsOpInterface::ExternalModel { @@ -98,6 +159,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) { AffineApplyOp::attachInterface(*ctx); + AffineDelinearizeIndexOp::attachInterface< + AffineDelinearizeIndexOpInterface>(*ctx); + AffineLinearizeIndexOp::attachInterface( + *ctx); AffineMaxOp::attachInterface(*ctx); AffineMinOp::attachInterface(*ctx); }); diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir index 935c08aceff54..2184d7fa5074e 100644 --- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir @@ -155,3 +155,40 @@ func.func @compare_maps(%a: index, %b: index) { : (index, index, index, index) -> () return } + +// ----- + +func.func @compare_affine_linearize_index(%a: index, %b: index) { + %0 = affine.linearize_index [%a, %b] by (2, 4) : index + %1 = affine.linearize_index [%a, %b] by (4) : index + // expected-remark @below{{true}} + "test.compare"(%0, %a, %b) {rhs_map = affine_map<(a, b) -> (a * 4 + b)>} + : (index, index, index) -> () + // expected-remark @below{{true}} + "test.compare"(%1, %a, %b) {rhs_map = affine_map<(a, b) -> (a * 4 + b)>} + : (index, index, index) -> () + return +} + +// ----- + +// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 4)> +// CHECK: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 4)> + +// CHECK-LABEL: func @affine_delinearize_index( +// CHECK-SAME: %[[a:.*]]: index +func.func @affine_delinearize_index(%a: index) -> (index, index, index, index) { + %0:2 = affine.delinearize_index %a into (2, 4) : index, index + %1:2 = affine.delinearize_index %a into (4) : index, index + + // CHECK: %[[BOUND0:.+]] = affine.apply #[[$MAP]]()[%[[a]]] + %2 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index) + // CHECK: %[[BOUND1:.+]] = affine.apply #[[$MAP1]]()[%[[a]]] + %3 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index) + // CHECK: %[[BOUND2:.+]] = affine.apply #[[$MAP]]()[%[[a]]] + %4 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index) + // CHECK: %[[BOUND3:.+]] = affine.apply #[[$MAP1]]()[%[[a]]] + %5 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index) + // CHECK: return %[[BOUND0]], %[[BOUND1]], %[[BOUND2]], %[[BOUND3]] + return %2, %3, %4, %5: index, index, index, index +}