From 185863695e191a8dce6b37f38bd2dba086da08d3 Mon Sep 17 00:00:00 2001 From: lipracer Date: Fri, 20 Sep 2024 14:40:31 +0800 Subject: [PATCH 1/5] [mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression not satisfying commutative Fixes https://github.com/llvm/llvm-project/issues/107508 --- mlir/lib/IR/AffineExpr.cpp | 164 ++++++++++++------ .../Dialect/Affine/simplify-structures.mlir | 22 ++- 2 files changed, 134 insertions(+), 52 deletions(-) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 0b078966aeb85..cc8c4c21b96ce 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include #include "AffineExprDetail.h" @@ -18,9 +20,8 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/Support/TypeID.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/MathExtras.h" -#include -#include using namespace mlir; using namespace mlir::detail; @@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || opKind == AffineExprKind::CeilDiv) && "unexpected opKind"); - switch (expr.getKind()) { - case AffineExprKind::Constant: - return cast(expr).getValue() == 0; - case AffineExprKind::DimId: - return false; - case AffineExprKind::SymbolId: - return (cast(expr).getPosition() == symbolPos); - // Checks divisibility by the given symbol for both operands. - case AffineExprKind::Add: { - AffineBinaryOpExpr binaryExpr = cast(expr); - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) && - isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); - } - // Checks divisibility by the given symbol for both operands. Consider the - // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`, - // this is a division by s1 and both the operands of modulo are divisible by - // s1 but it is not divisible by s1 always. The third argument is - // `AffineExprKind::Mod` for this reason. - case AffineExprKind::Mod: { - AffineBinaryOpExpr binaryExpr = cast(expr); - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, - AffineExprKind::Mod) && - isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, - AffineExprKind::Mod); - } - // Checks if any of the operand divisible by the given symbol. - case AffineExprKind::Mul: { - AffineBinaryOpExpr binaryExpr = cast(expr); - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) || - isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); - } - // Floordiv and ceildiv are divisible by the given symbol when the first - // operand is divisible, and the affine expression kind of the argument expr - // is same as the argument `opKind`. This can be inferred from commutative - // property of floordiv and ceildiv operations and are as follow: - // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 - // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 - // It will fail if operations are not same. For example: - // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. - case AffineExprKind::FloorDiv: - case AffineExprKind::CeilDiv: { - AffineBinaryOpExpr binaryExpr = cast(expr); - if (opKind != expr.getKind()) - return false; - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()); - } + std::vector>>> + stack; + stack.emplace_back(expr, symbolPos, opKind, []() {}); + bool result = false; + + while (!stack.empty()) { + AffineExpr expr = std::get<0>(stack.back()); + unsigned symbolPos = std::get<1>(stack.back()); + AffineExprKind opKind = std::get<2>(stack.back()); + + switch (expr.getKind()) { + case AffineExprKind::Constant: { + // Note: Assignment must occur before pop, which will affect whether it + // enters other execution branches. + result = cast(expr).getValue() == 0; + stack.pop_back(); + break; + } + case AffineExprKind::DimId: { + result = false; + stack.pop_back(); + break; + } + case AffineExprKind::SymbolId: { + result = cast(expr).getPosition() == symbolPos; + stack.pop_back(); + break; + } + // Checks divisibility by the given symbol for both operands. + case AffineExprKind::Add: { + AffineBinaryOpExpr binaryExpr = cast(expr); + stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind, + [&stack, &result, binaryExpr, symbolPos, opKind]() { + if (result) { + stack.emplace_back( + binaryExpr.getRHS(), symbolPos, opKind, + [&stack]() { stack.pop_back(); }); + } else { + stack.pop_back(); + } + }); + break; + } + // Checks divisibility by the given symbol for both operands. Consider the + // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv + // s1`, this is a division by s1 and both the operands of modulo are + // divisible by s1 but it is not divisible by s1 always. The third argument + // is `AffineExprKind::Mod` for this reason. + case AffineExprKind::Mod: { + AffineBinaryOpExpr binaryExpr = cast(expr); + stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod, + [&stack, &result, binaryExpr, symbolPos, opKind]() { + if (result) { + stack.emplace_back( + binaryExpr.getRHS(), symbolPos, + AffineExprKind::Mod, + [&stack]() { stack.pop_back(); }); + } else { + stack.pop_back(); + } + }); + break; + } + // Checks if any of the operand divisible by the given symbol. + case AffineExprKind::Mul: { + AffineBinaryOpExpr binaryExpr = cast(expr); + stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind, + [&stack, &result, binaryExpr, symbolPos, opKind]() { + if (!result) { + stack.emplace_back( + binaryExpr.getRHS(), symbolPos, opKind, + [&stack]() { stack.pop_back(); }); + } else { + stack.pop_back(); + } + }); + break; + } + // Floordiv and ceildiv are divisible by the given symbol when the first + // operand is divisible, and the affine expression kind of the argument expr + // is same as the argument `opKind`. This can be inferred from commutative + // property of floordiv and ceildiv operations and are as follow: + // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 + // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 + // It will fail 1.if operations are not same. For example: + // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a + // multiplication operation in the expression. For example: + // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified. + case AffineExprKind::FloorDiv: + case AffineExprKind::CeilDiv: { + AffineBinaryOpExpr binaryExpr = cast(expr); + if (opKind != expr.getKind()) { + result = false; + stack.pop_back(); + break; + } + if (llvm::any_of(stack, [](auto &it) { + return std::get<0>(it).getKind() == AffineExprKind::Mul; + })) { + result = false; + stack.pop_back(); + break; + } + + stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(), + [&stack]() { stack.pop_back(); }); + break; + } + llvm_unreachable("Unknown AffineExpr"); + } } - llvm_unreachable("Unknown AffineExpr"); + return result; } /// Divides the given expression by the given symbol at position `symbolPos`. It diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir index 92d3d86bc9306..d1f34f20fa5da 100644 --- a/mlir/test/Dialect/Affine/simplify-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-structures.mlir @@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index { } // Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv. -// CHECK-LABEL: func @semiaffine_composite_floor -func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index { +// CHECK-LABEL: func @semiaffine_composite_ceildiv +func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index { + %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1] + // CHECK: %[[CST:.*]] = arith.constant 43 + return %a : index +} + +// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation. +// CHECK-LABEL: func @semiaffine_composite_ceildiv +func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index { %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1] - // CHECK: %[[CST:.*]] = arith.constant 47 + // CHECK-NOT: arith.constant + return %a : index +} + +// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation +// CHECK-LABEL: func @semiaffine_composite_floordiv +func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index { + %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1] + // CHECK-NOT: arith.constant return %a : index } From c75822467ae32398e586cbe8badb1dfc1d89e386 Mon Sep 17 00:00:00 2001 From: lipracer Date: Fri, 20 Sep 2024 18:51:51 +0800 Subject: [PATCH 2/5] refine --- mlir/lib/IR/AffineExpr.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index cc8c4c21b96ce..05ce0937fa9db 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -415,7 +415,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, case AffineExprKind::Mod: { AffineBinaryOpExpr binaryExpr = cast(expr); stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod, - [&stack, &result, binaryExpr, symbolPos, opKind]() { + [&stack, &result, binaryExpr, symbolPos]() { if (result) { stack.emplace_back( binaryExpr.getRHS(), symbolPos, From dee51cdf83bcf865c76ea792a874314b7b30f30e Mon Sep 17 00:00:00 2001 From: lipracer Date: Tue, 24 Sep 2024 16:52:02 +0800 Subject: [PATCH 3/5] refine --- mlir/lib/IR/AffineExpr.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 05ce0937fa9db..b3d8f1cd0c731 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -363,7 +363,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || opKind == AffineExprKind::CeilDiv) && "unexpected opKind"); - std::vector>>> stack; stack.emplace_back(expr, symbolPos, opKind, []() {}); From c70e79414ffef1d1f81a0d2d30e1d265aa4bdf5c Mon Sep 17 00:00:00 2001 From: lipracer Date: Sat, 28 Sep 2024 08:07:12 -0400 Subject: [PATCH 4/5] fix ci fail on windows platform --- mlir/lib/IR/AffineExpr.cpp | 98 +++++++++++++++++++++++++------------- 1 file changed, 65 insertions(+), 33 deletions(-) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index b3d8f1cd0c731..ef037a2e13365 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -379,32 +379,44 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, // Note: Assignment must occur before pop, which will affect whether it // enters other execution branches. result = cast(expr).getValue() == 0; + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); stack.pop_back(); break; } case AffineExprKind::DimId: { result = false; + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); stack.pop_back(); break; } case AffineExprKind::SymbolId: { result = cast(expr).getPosition() == symbolPos; + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); stack.pop_back(); break; } // Checks divisibility by the given symbol for both operands. case AffineExprKind::Add: { AffineBinaryOpExpr binaryExpr = cast(expr); - stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind, - [&stack, &result, binaryExpr, symbolPos, opKind]() { - if (result) { - stack.emplace_back( - binaryExpr.getRHS(), symbolPos, opKind, - [&stack]() { stack.pop_back(); }); - } else { - stack.pop_back(); - } - }); + stack.emplace_back( + binaryExpr.getLHS(), symbolPos, opKind, + [&stack, &result, binaryExpr, symbolPos, opKind]() { + if (result) { + stack.emplace_back( + binaryExpr.getRHS(), symbolPos, opKind, [&stack]() { + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); + stack.pop_back(); + }); + } else { + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); + stack.pop_back(); + } + }); break; } // Checks divisibility by the given symbol for both operands. Consider the @@ -414,32 +426,44 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, // is `AffineExprKind::Mod` for this reason. case AffineExprKind::Mod: { AffineBinaryOpExpr binaryExpr = cast(expr); - stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod, - [&stack, &result, binaryExpr, symbolPos]() { - if (result) { - stack.emplace_back( - binaryExpr.getRHS(), symbolPos, - AffineExprKind::Mod, - [&stack]() { stack.pop_back(); }); - } else { - stack.pop_back(); - } - }); + stack.emplace_back( + binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod, + [&stack, &result, binaryExpr, symbolPos]() { + if (result) { + stack.emplace_back( + binaryExpr.getRHS(), symbolPos, AffineExprKind::Mod, + [&stack]() { + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); + stack.pop_back(); + }); + } else { + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); + stack.pop_back(); + } + }); break; } // Checks if any of the operand divisible by the given symbol. case AffineExprKind::Mul: { AffineBinaryOpExpr binaryExpr = cast(expr); - stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind, - [&stack, &result, binaryExpr, symbolPos, opKind]() { - if (!result) { - stack.emplace_back( - binaryExpr.getRHS(), symbolPos, opKind, - [&stack]() { stack.pop_back(); }); - } else { - stack.pop_back(); - } - }); + stack.emplace_back( + binaryExpr.getLHS(), symbolPos, opKind, + [&stack, &result, binaryExpr, symbolPos, opKind]() { + if (!result) { + stack.emplace_back( + binaryExpr.getRHS(), symbolPos, opKind, [&stack]() { + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); + stack.pop_back(); + }); + } else { + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); + stack.pop_back(); + } + }); break; } // Floordiv and ceildiv are divisible by the given symbol when the first @@ -457,6 +481,8 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, AffineBinaryOpExpr binaryExpr = cast(expr); if (opKind != expr.getKind()) { result = false; + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); stack.pop_back(); break; } @@ -464,12 +490,18 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, return std::get<0>(it).getKind() == AffineExprKind::Mul; })) { result = false; + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); stack.pop_back(); break; } - stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(), - [&stack]() { stack.pop_back(); }); + stack.emplace_back( + binaryExpr.getLHS(), symbolPos, expr.getKind(), [&stack]() { + llvm::detail::scope_exit> sexit( + std::move(std::get<3>(stack.back()))); + stack.pop_back(); + }); break; } llvm_unreachable("Unknown AffineExpr"); From 514d78c113ebf3340d0bafbad49baaaca655db97 Mon Sep 17 00:00:00 2001 From: "long.chen" Date: Wed, 2 Oct 2024 04:51:01 +0000 Subject: [PATCH 5/5] refine --- mlir/lib/IR/AffineExpr.cpp | 196 +++++++++++++++++++++---------------- 1 file changed, 111 insertions(+), 85 deletions(-) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index ef037a2e13365..82b0bc193bfb7 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -350,6 +350,82 @@ unsigned AffineDimExpr::getPosition() const { return static_cast(expr)->position; } +/// A manually managed stack used to convert recursive function calls into +/// looping utility classes during the access tree structure process. This has +/// two benefits: one is to access the current stack, and the other is to avoid +/// stack explosion when the recursion depth is too deep. Typically, recursive +/// calls take the form of the following: +/// push node +/// visit tree node +/// push node->left_node +/// ... +/// pop left_node +/// check result and do something +/// push node->right_node +/// pop right_node +/// pop node +/// ... +/// This form can be converted into the following form: +/// push node +/// visit tree node +/// push node->left_node +/// ... +/// pop left_node and do { +/// check result and do something +/// push node->right_node +/// pop right_node and do { pop node } +/// } +/// ... +/// so we need to perform some operations +/// after an element is pushed out of the stack. We use the `scope_exit` +/// structure to invoke these operations. +template +class CallStack { +public: + using value_type = + std::tuple>>; + CallStack(ArgsT... args) { + stack_.emplace_back(args..., []() {}); + } + + /// Push the parameters into the stack and record the operation to be executed + /// when the node access ends. By default, the previous stack element will pop + /// up. If you need to check the result of the current push to the stack, you + /// need to pass in a function and manually perform the push operation after + /// the function ends. + void pushArgs(ArgsT... args, const std::function &onExit = {}) { + if (onExit) + stack_.emplace_back(args..., onExit); + else + stack_.emplace_back(args..., [this]() { pop(); }); + } + + RetT getResult() const { return value_; } + + void returnValue(RetT value) { + value_ = value; + pop(); + } + + value_type &top() { return stack_.back(); } + + auto begin() const { return stack_.begin(); } + + auto end() const { return stack_.end(); } + + bool empty() const { return stack_.empty(); } + +private: + /// Note: We must move the top element of the stack and then perform the + /// stack pop operation. If we directly pop the stack, the `scope_exit` may + /// modify the stack, which may cause the program on the Windows platform to + /// crash, but it works normally on Ubuntu.git + void pop() { value_type _(stack_.pop_back_val()); } + + SmallVector stack_; + RetT value_; +}; + /// Returns true if the expression is divisible by the given symbol with /// position `symbolPos`. The argument `opKind` specifies here what kind of /// division or mod operation called this division. It helps in implementing the @@ -363,60 +439,40 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || opKind == AffineExprKind::CeilDiv) && "unexpected opKind"); - SmallVector>>> - stack; - stack.emplace_back(expr, symbolPos, opKind, []() {}); - bool result = false; + CallStack stack(expr, symbolPos, + opKind); while (!stack.empty()) { - AffineExpr expr = std::get<0>(stack.back()); - unsigned symbolPos = std::get<1>(stack.back()); - AffineExprKind opKind = std::get<2>(stack.back()); + AffineExpr expr = std::get<0>(stack.top()); + unsigned symbolPos = std::get<1>(stack.top()); + AffineExprKind opKind = std::get<2>(stack.top()); switch (expr.getKind()) { case AffineExprKind::Constant: { // Note: Assignment must occur before pop, which will affect whether it // enters other execution branches. - result = cast(expr).getValue() == 0; - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); + stack.returnValue(cast(expr).getValue() == 0); break; } case AffineExprKind::DimId: { - result = false; - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); + stack.returnValue(false); break; } case AffineExprKind::SymbolId: { - result = cast(expr).getPosition() == symbolPos; - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); + stack.returnValue(cast(expr).getPosition() == + symbolPos); break; } // Checks divisibility by the given symbol for both operands. case AffineExprKind::Add: { AffineBinaryOpExpr binaryExpr = cast(expr); - stack.emplace_back( - binaryExpr.getLHS(), symbolPos, opKind, - [&stack, &result, binaryExpr, symbolPos, opKind]() { - if (result) { - stack.emplace_back( - binaryExpr.getRHS(), symbolPos, opKind, [&stack]() { - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); - }); - } else { - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); - } - }); + stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind, + [&stack, binaryExpr, symbolPos, opKind]() { + if (stack.getResult()) + stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind); + else + stack.returnValue(stack.getResult()); + }); break; } // Checks divisibility by the given symbol for both operands. Consider the @@ -426,44 +482,26 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, // is `AffineExprKind::Mod` for this reason. case AffineExprKind::Mod: { AffineBinaryOpExpr binaryExpr = cast(expr); - stack.emplace_back( - binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod, - [&stack, &result, binaryExpr, symbolPos]() { - if (result) { - stack.emplace_back( - binaryExpr.getRHS(), symbolPos, AffineExprKind::Mod, - [&stack]() { - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); - }); - } else { - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); - } - }); + stack.pushArgs(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod, + [&stack, binaryExpr, symbolPos]() { + if (stack.getResult()) + stack.pushArgs(binaryExpr.getRHS(), symbolPos, + AffineExprKind::Mod); + else + stack.returnValue(stack.getResult()); + }); break; } // Checks if any of the operand divisible by the given symbol. case AffineExprKind::Mul: { AffineBinaryOpExpr binaryExpr = cast(expr); - stack.emplace_back( - binaryExpr.getLHS(), symbolPos, opKind, - [&stack, &result, binaryExpr, symbolPos, opKind]() { - if (!result) { - stack.emplace_back( - binaryExpr.getRHS(), symbolPos, opKind, [&stack]() { - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); - }); - } else { - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); - } - }); + stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind, + [&stack, binaryExpr, symbolPos, opKind]() { + if (!stack.getResult()) + stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind); + else + stack.returnValue(stack.getResult()); + }); break; } // Floordiv and ceildiv are divisible by the given symbol when the first @@ -480,34 +518,22 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, case AffineExprKind::CeilDiv: { AffineBinaryOpExpr binaryExpr = cast(expr); if (opKind != expr.getKind()) { - result = false; - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); + stack.returnValue(false); break; } if (llvm::any_of(stack, [](auto &it) { return std::get<0>(it).getKind() == AffineExprKind::Mul; })) { - result = false; - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); + stack.returnValue(false); break; } - - stack.emplace_back( - binaryExpr.getLHS(), symbolPos, expr.getKind(), [&stack]() { - llvm::detail::scope_exit> sexit( - std::move(std::get<3>(stack.back()))); - stack.pop_back(); - }); + stack.pushArgs(binaryExpr.getLHS(), symbolPos, expr.getKind()); break; } llvm_unreachable("Unknown AffineExpr"); } } - return result; + return stack.getResult(); } /// Divides the given expression by the given symbol at position `symbolPos`. It