diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 0b078966aeb85..82b0bc193bfb7 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include #include "AffineExprDetail.h" @@ -18,9 +20,8 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/Support/TypeID.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/MathExtras.h" -#include -#include using namespace mlir; using namespace mlir::detail; @@ -349,6 +350,82 @@ unsigned AffineDimExpr::getPosition() const { return static_cast(expr)->position; } +/// A manually managed stack used to convert recursive function calls into +/// looping utility classes during the access tree structure process. This has +/// two benefits: one is to access the current stack, and the other is to avoid +/// stack explosion when the recursion depth is too deep. Typically, recursive +/// calls take the form of the following: +/// push node +/// visit tree node +/// push node->left_node +/// ... +/// pop left_node +/// check result and do something +/// push node->right_node +/// pop right_node +/// pop node +/// ... +/// This form can be converted into the following form: +/// push node +/// visit tree node +/// push node->left_node +/// ... +/// pop left_node and do { +/// check result and do something +/// push node->right_node +/// pop right_node and do { pop node } +/// } +/// ... +/// so we need to perform some operations +/// after an element is pushed out of the stack. We use the `scope_exit` +/// structure to invoke these operations. +template +class CallStack { +public: + using value_type = + std::tuple>>; + CallStack(ArgsT... args) { + stack_.emplace_back(args..., []() {}); + } + + /// Push the parameters into the stack and record the operation to be executed + /// when the node access ends. By default, the previous stack element will pop + /// up. If you need to check the result of the current push to the stack, you + /// need to pass in a function and manually perform the push operation after + /// the function ends. + void pushArgs(ArgsT... args, const std::function &onExit = {}) { + if (onExit) + stack_.emplace_back(args..., onExit); + else + stack_.emplace_back(args..., [this]() { pop(); }); + } + + RetT getResult() const { return value_; } + + void returnValue(RetT value) { + value_ = value; + pop(); + } + + value_type &top() { return stack_.back(); } + + auto begin() const { return stack_.begin(); } + + auto end() const { return stack_.end(); } + + bool empty() const { return stack_.empty(); } + +private: + /// Note: We must move the top element of the stack and then perform the + /// stack pop operation. If we directly pop the stack, the `scope_exit` may + /// modify the stack, which may cause the program on the Windows platform to + /// crash, but it works normally on Ubuntu.git + void pop() { value_type _(stack_.pop_back_val()); } + + SmallVector stack_; + RetT value_; +}; + /// Returns true if the expression is divisible by the given symbol with /// position `symbolPos`. The argument `opKind` specifies here what kind of /// division or mod operation called this division. It helps in implementing the @@ -362,54 +439,101 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || opKind == AffineExprKind::CeilDiv) && "unexpected opKind"); - switch (expr.getKind()) { - case AffineExprKind::Constant: - return cast(expr).getValue() == 0; - case AffineExprKind::DimId: - return false; - case AffineExprKind::SymbolId: - return (cast(expr).getPosition() == symbolPos); - // Checks divisibility by the given symbol for both operands. - case AffineExprKind::Add: { - AffineBinaryOpExpr binaryExpr = cast(expr); - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) && - isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); - } - // Checks divisibility by the given symbol for both operands. Consider the - // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`, - // this is a division by s1 and both the operands of modulo are divisible by - // s1 but it is not divisible by s1 always. The third argument is - // `AffineExprKind::Mod` for this reason. - case AffineExprKind::Mod: { - AffineBinaryOpExpr binaryExpr = cast(expr); - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, - AffineExprKind::Mod) && - isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, - AffineExprKind::Mod); - } - // Checks if any of the operand divisible by the given symbol. - case AffineExprKind::Mul: { - AffineBinaryOpExpr binaryExpr = cast(expr); - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) || - isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); - } - // Floordiv and ceildiv are divisible by the given symbol when the first - // operand is divisible, and the affine expression kind of the argument expr - // is same as the argument `opKind`. This can be inferred from commutative - // property of floordiv and ceildiv operations and are as follow: - // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 - // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 - // It will fail if operations are not same. For example: - // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. - case AffineExprKind::FloorDiv: - case AffineExprKind::CeilDiv: { - AffineBinaryOpExpr binaryExpr = cast(expr); - if (opKind != expr.getKind()) - return false; - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()); - } + CallStack stack(expr, symbolPos, + opKind); + + while (!stack.empty()) { + AffineExpr expr = std::get<0>(stack.top()); + unsigned symbolPos = std::get<1>(stack.top()); + AffineExprKind opKind = std::get<2>(stack.top()); + + switch (expr.getKind()) { + case AffineExprKind::Constant: { + // Note: Assignment must occur before pop, which will affect whether it + // enters other execution branches. + stack.returnValue(cast(expr).getValue() == 0); + break; + } + case AffineExprKind::DimId: { + stack.returnValue(false); + break; + } + case AffineExprKind::SymbolId: { + stack.returnValue(cast(expr).getPosition() == + symbolPos); + break; + } + // Checks divisibility by the given symbol for both operands. + case AffineExprKind::Add: { + AffineBinaryOpExpr binaryExpr = cast(expr); + stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind, + [&stack, binaryExpr, symbolPos, opKind]() { + if (stack.getResult()) + stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind); + else + stack.returnValue(stack.getResult()); + }); + break; + } + // Checks divisibility by the given symbol for both operands. Consider the + // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv + // s1`, this is a division by s1 and both the operands of modulo are + // divisible by s1 but it is not divisible by s1 always. The third argument + // is `AffineExprKind::Mod` for this reason. + case AffineExprKind::Mod: { + AffineBinaryOpExpr binaryExpr = cast(expr); + stack.pushArgs(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod, + [&stack, binaryExpr, symbolPos]() { + if (stack.getResult()) + stack.pushArgs(binaryExpr.getRHS(), symbolPos, + AffineExprKind::Mod); + else + stack.returnValue(stack.getResult()); + }); + break; + } + // Checks if any of the operand divisible by the given symbol. + case AffineExprKind::Mul: { + AffineBinaryOpExpr binaryExpr = cast(expr); + stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind, + [&stack, binaryExpr, symbolPos, opKind]() { + if (!stack.getResult()) + stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind); + else + stack.returnValue(stack.getResult()); + }); + break; + } + // Floordiv and ceildiv are divisible by the given symbol when the first + // operand is divisible, and the affine expression kind of the argument expr + // is same as the argument `opKind`. This can be inferred from commutative + // property of floordiv and ceildiv operations and are as follow: + // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 + // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 + // It will fail 1.if operations are not same. For example: + // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a + // multiplication operation in the expression. For example: + // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified. + case AffineExprKind::FloorDiv: + case AffineExprKind::CeilDiv: { + AffineBinaryOpExpr binaryExpr = cast(expr); + if (opKind != expr.getKind()) { + stack.returnValue(false); + break; + } + if (llvm::any_of(stack, [](auto &it) { + return std::get<0>(it).getKind() == AffineExprKind::Mul; + })) { + stack.returnValue(false); + break; + } + stack.pushArgs(binaryExpr.getLHS(), symbolPos, expr.getKind()); + break; + } + llvm_unreachable("Unknown AffineExpr"); + } } - llvm_unreachable("Unknown AffineExpr"); + return stack.getResult(); } /// Divides the given expression by the given symbol at position `symbolPos`. It diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir index 92d3d86bc9306..d1f34f20fa5da 100644 --- a/mlir/test/Dialect/Affine/simplify-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-structures.mlir @@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index { } // Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv. -// CHECK-LABEL: func @semiaffine_composite_floor -func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index { +// CHECK-LABEL: func @semiaffine_composite_ceildiv +func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index { + %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1] + // CHECK: %[[CST:.*]] = arith.constant 43 + return %a : index +} + +// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation. +// CHECK-LABEL: func @semiaffine_composite_ceildiv +func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index { %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1] - // CHECK: %[[CST:.*]] = arith.constant 47 + // CHECK-NOT: arith.constant + return %a : index +} + +// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation +// CHECK-LABEL: func @semiaffine_composite_floordiv +func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index { + %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1] + // CHECK-NOT: arith.constant return %a : index }