222 changes: 173 additions & 49 deletions mlir/lib/IR/AffineExpr.cpp
@@ -9,6 +9,8 @@
#include <cmath>
#include <cstdint>
#include <limits>
#include <numeric>
#include <optional>
#include <utility>

#include "AffineExprDetail.h"
@@ -18,9 +20,8 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
#include <optional>

using namespace mlir;
using namespace mlir::detail;
@@ -349,6 +350,82 @@ unsigned AffineDimExpr::getPosition() const {
return static_cast<ImplType *>(expr)->position;
}

/// A manually managed stack used to convert recursive function calls into
Member

Why not just use AffineExpr::walk? It'll do a post-order traversal, so you could just have a stack of partial results. Unless I'm still missing something :-).

Member Author

This requires a pre-order traversal. Perhaps we can refactor AffineExprVisitor to support pre-order traversal.

Member

Can you give me a small example/explanation where pre-order traversal is needed?

Member Author

Sorry, I didn't express myself clearly. The implementation of isDivisibleBySymbol requires a pre-order traversal, e.g.:

case add: visit(lhs) and visit(rhs)
case mul: visit(lhs) or visit(rhs)

The traversal is controlled based on the current expr type. Perhaps returning both an interrupt and a boolean from the visitor could meet this requirement, but that would also need to be reimplemented.
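
For illustration only (not part of the patch): a minimal sketch of the controlled, short-circuiting recursion described above. It assumes a hypothetical toy `Expr` type rather than MLIR's `AffineExpr`, and it leaves out the `opKind` handling for mod, floordiv and ceildiv. The names `Expr`, `leaf`, `binary` and `divisibleBySymbol` are made up for this sketch.

```cpp
#include <memory>
#include <utility>

// Hypothetical toy expression tree; a stand-in for AffineExpr, not MLIR API.
enum class Kind { Constant, Dim, Symbol, Add, Mul };

struct Expr {
  Kind kind = Kind::Constant;
  long value = 0;                 // constant value, or dim/symbol position
  std::shared_ptr<Expr> lhs, rhs; // set for Add and Mul only
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr leaf(Kind kind, long value) {
  auto e = std::make_shared<Expr>();
  e->kind = kind;
  e->value = value;
  return e;
}

ExprPtr binary(Kind kind, ExprPtr lhs, ExprPtr rhs) {
  auto e = std::make_shared<Expr>();
  e->kind = kind;
  e->lhs = std::move(lhs);
  e->rhs = std::move(rhs);
  return e;
}

// The parent decides whether the right operand is visited at all:
// Add stops at the first `false`, Mul stops at the first `true`.
bool divisibleBySymbol(const Expr &e, long symbolPos) {
  switch (e.kind) {
  case Kind::Constant:
    return e.value == 0;
  case Kind::Dim:
    return false;
  case Kind::Symbol:
    return e.value == symbolPos;
  case Kind::Add:
    return divisibleBySymbol(*e.lhs, symbolPos) &&
           divisibleBySymbol(*e.rhs, symbolPos);
  case Kind::Mul:
    return divisibleBySymbol(*e.lhs, symbolPos) ||
           divisibleBySymbol(*e.rhs, symbolPos);
  }
  return false;
}
```

In this toy model, `s0 * d0 + s0` is reported as divisible by `s0`, while `d0 + s0` is not.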

Member

In a post-order traversal with a stack for intermediates, you'd have the results for lhs and rhs on the top of the stack when you visit an add, right? Can you give me a more complete example where this doesn't work?
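
For illustration only (not part of the patch): a rough sketch of the post-order idea, where each visited node pushes its partial result and a binary node pops the results of its two operands. It assumes a hypothetical toy `Expr` type and a hand-written `walkPostOrder` as a stand-in for a post-order walk; neither is MLIR API. The sketch does not short-circuit, and it ignores `opKind`, which the real function passes downward from parent to child.

```cpp
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy expression tree; a stand-in for AffineExpr, not MLIR API.
enum class Kind { Constant, Dim, Symbol, Add, Mul };

struct Expr {
  Kind kind = Kind::Constant;
  long value = 0;                 // constant value, or dim/symbol position
  std::shared_ptr<Expr> lhs, rhs; // set for Add and Mul only
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr leaf(Kind kind, long value) {
  auto e = std::make_shared<Expr>();
  e->kind = kind;
  e->value = value;
  return e;
}

ExprPtr binary(Kind kind, ExprPtr lhs, ExprPtr rhs) {
  auto e = std::make_shared<Expr>();
  e->kind = kind;
  e->lhs = std::move(lhs);
  e->rhs = std::move(rhs);
  return e;
}

// Visits both operands before the node itself (post-order).
void walkPostOrder(const Expr &e, const std::function<void(const Expr &)> &fn) {
  if (e.lhs)
    walkPostOrder(*e.lhs, fn);
  if (e.rhs)
    walkPostOrder(*e.rhs, fn);
  fn(e);
}

bool divisibleBySymbol(const Expr &root, long symbolPos) {
  std::vector<bool> results; // stack of partial results
  walkPostOrder(root, [&](const Expr &e) {
    switch (e.kind) {
    case Kind::Constant:
      results.push_back(e.value == 0);
      break;
    case Kind::Dim:
      results.push_back(false);
      break;
    case Kind::Symbol:
      results.push_back(e.value == symbolPos);
      break;
    case Kind::Add:
    case Kind::Mul: {
      // The operand results are the two topmost entries: rhs, then lhs.
      bool rhs = results.back();
      results.pop_back();
      bool lhs = results.back();
      results.pop_back();
      results.push_back(e.kind == Kind::Add ? (lhs && rhs) : (lhs || rhs));
      break;
    }
    }
  });
  return results.back();
}
```

The trade-off in this form is that every node is visited even when the answer is already known, which the explicit control flow of the recursive version avoids.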

Member

I see no problem with using an explicit stack in addition to the implicit one in ::walk.

If you're concerned about stack overflows, the correct thing to do is to refactor ::walk to not be recursive, since that function is used in many places, so fixing this particular location isn't really sufficient if stack overflows are an issue with ::walk.
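
For illustration only (not part of the patch, and not how `walk` is implemented): a minimal sketch of a post-order walk driven by an explicit stack instead of recursion. The toy `Expr` type and the name `walkPostOrderIterative` are assumptions made for this sketch.

```cpp
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy expression tree; a stand-in for AffineExpr, not MLIR API.
enum class Kind { Constant, Dim, Symbol, Add, Mul };

struct Expr {
  Kind kind = Kind::Constant;
  long value = 0;                 // constant value, or dim/symbol position
  std::shared_ptr<Expr> lhs, rhs; // set for Add and Mul only
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr leaf(Kind kind, long value) {
  auto e = std::make_shared<Expr>();
  e->kind = kind;
  e->value = value;
  return e;
}

ExprPtr binary(Kind kind, ExprPtr lhs, ExprPtr rhs) {
  auto e = std::make_shared<Expr>();
  e->kind = kind;
  e->lhs = std::move(lhs);
  e->rhs = std::move(rhs);
  return e;
}

// Post-order walk without recursion. Each stack entry records whether the
// node's operands have already been scheduled; the callback runs only when
// the entry is popped for the second time.
void walkPostOrderIterative(const Expr &root,
                            const std::function<void(const Expr &)> &fn) {
  std::vector<std::pair<const Expr *, bool>> stack;
  stack.push_back({&root, false});
  while (!stack.empty()) {
    const Expr *node = stack.back().first;
    bool operandsScheduled = stack.back().second;
    stack.pop_back();
    if (operandsScheduled) {
      fn(*node);
      continue;
    }
    stack.push_back({node, true});
    // Push rhs first so that lhs is processed first.
    if (node->rhs)
      stack.push_back({node->rhs.get(), false});
    if (node->lhs)
      stack.push_back({node->lhs.get(), false});
  }
}
```

For `(s0 * d0) + 0` the callback sees the nodes in the order symbol, dim, mul, constant, add, which is the same order a recursive post-order walk would produce.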

Collaborator

Refactoring MLIR walk() to not be recursive is a long-standing issue: patches welcome to fix this!

Member Author

OK. I will switch to a `walk` approach and then, in another patch, refactor AffineExprVisitor to use a stack and avoid recursive calls.

Member Author

On why I don't use walk:
This is the walk implementation. At this point, lhs and rhs have already been traversed, and I don't know where to do the stack push.

Member Author

@joker-eph The walk here is the traversal method of AffineExpr. When you mention mlir::walk, do you mean the one on mlir::Operation?

/// loops when traversing a tree structure. This has two benefits: the current
/// stack can be inspected, and deep recursion cannot overflow the call stack.
/// Typically, recursive calls take the following form:
/// push node
/// visit tree node
/// push node->left_node
/// ...
/// pop left_node
/// check result and do something
/// push node->right_node
/// pop right_node
/// pop node
/// ...
/// This form can be converted into the following form:
/// push node
/// visit tree node
/// push node->left_node
/// ...
/// pop left_node and do {
/// check result and do something
/// push node->right_node
/// pop right_node and do { pop node }
/// }
/// ...
/// so we need to perform some operations after an element is popped from the
/// stack. We use the `scope_exit` structure to invoke these operations.
template <typename RetT, typename... ArgsT>
class CallStack {
public:
using value_type =
std::tuple<ArgsT..., llvm::detail::scope_exit<std::function<void(void)>>>;
CallStack(ArgsT... args) {
stack_.emplace_back(args..., []() {});
}

/// Pushes the arguments onto the stack and records the action to run when the
/// visit of this element ends. By default, that action pops the previous
/// stack element. To inspect the result of the pushed element, pass a
/// callback; the callback is then responsible for continuing the traversal
/// itself, either by pushing the next element or by returning a value.
void pushArgs(ArgsT... args, const std::function<void(void)> &onExit = {}) {
if (onExit)
stack_.emplace_back(args..., onExit);
else
stack_.emplace_back(args..., [this]() { pop(); });
}

RetT getResult() const { return value_; }

void returnValue(RetT value) {
value_ = value;
pop();
}

value_type &top() { return stack_.back(); }

auto begin() const { return stack_.begin(); }

auto end() const { return stack_.end(); }

bool empty() const { return stack_.empty(); }

private:
/// Note: We must move the top element out of the stack before destroying it.
/// If we pop the stack directly, the `scope_exit` may modify the stack during
/// the pop, which was observed to crash on Windows although it works on
/// Ubuntu.
void pop() { value_type _(stack_.pop_back_val()); }

SmallVector<value_type> stack_;
RetT value_;
};

/// Returns true if the expression is divisible by the given symbol with
/// position `symbolPos`. The argument `opKind` specifies here what kind of
/// division or mod operation called this division. It helps in implementing the
@@ -362,54 +439,101 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
switch (expr.getKind()) {
case AffineExprKind::Constant:
return cast<AffineConstantExpr>(expr).getValue() == 0;
case AffineExprKind::DimId:
return false;
case AffineExprKind::SymbolId:
return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
}
// Checks divisibility by the given symbol for both operands. Consider the
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
// this is a division by s1 and both the operands of modulo are divisible by
// s1 but it is not divisible by s1 always. The third argument is
// `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
AffineExprKind::Mod) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
AffineExprKind::Mod);
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
}
// Floordiv and ceildiv are divisible by the given symbol when the first
// operand is divisible, and the affine expression kind of the argument expr
// is same as the argument `opKind`. This can be inferred from commutative
// property of floordiv and ceildiv operations and are as follow:
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
// It will fail if operations are not same. For example:
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind())
return false;
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
}
CallStack<bool, AffineExpr, unsigned, AffineExprKind> stack(expr, symbolPos,
opKind);

while (!stack.empty()) {
AffineExpr expr = std::get<0>(stack.top());
unsigned symbolPos = std::get<1>(stack.top());
AffineExprKind opKind = std::get<2>(stack.top());

switch (expr.getKind()) {
case AffineExprKind::Constant: {
// Note: The result must be assigned before the pop, because the pop runs
// the exit callback, which reads the result to decide which branch to take.
stack.returnValue(cast<AffineConstantExpr>(expr).getValue() == 0);
break;
}
case AffineExprKind::DimId: {
stack.returnValue(false);
break;
}
case AffineExprKind::SymbolId: {
stack.returnValue(cast<AffineSymbolExpr>(expr).getPosition() ==
symbolPos);
break;
}
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind,
[&stack, binaryExpr, symbolPos, opKind]() {
if (stack.getResult())
stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind);
else
stack.returnValue(stack.getResult());
});
break;
}
// Checks divisibility by the given symbol for both operands. Consider the
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
// s1`, this is a division by s1 and both the operands of modulo are
// divisible by s1 but it is not divisible by s1 always. The third argument
// is `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
stack.pushArgs(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
[&stack, binaryExpr, symbolPos]() {
if (stack.getResult())
stack.pushArgs(binaryExpr.getRHS(), symbolPos,
AffineExprKind::Mod);
else
stack.returnValue(stack.getResult());
});
break;
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind,
[&stack, binaryExpr, symbolPos, opKind]() {
if (!stack.getResult())
stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind);
else
stack.returnValue(stack.getResult());
});
break;
}
// Floordiv and ceildiv are divisible by the given symbol when the first
// operand is divisible, and the affine expression kind of the argument expr
// is same as the argument `opKind`. This can be inferred from commutative
// property of floordiv and ceildiv operations and are as follow:
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
// It will fail in two cases:
// 1. The operations are not the same. For example,
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
// 2. There is a multiplication operation in the expression. For example,
// (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind()) {
stack.returnValue(false);
break;
}
if (llvm::any_of(stack, [](auto &it) {
return std::get<0>(it).getKind() == AffineExprKind::Mul;
})) {
stack.returnValue(false);
break;
}
stack.pushArgs(binaryExpr.getLHS(), symbolPos, expr.getKind());
break;
}
llvm_unreachable("Unknown AffineExpr");
}
}
llvm_unreachable("Unknown AffineExpr");
return stack.getResult();
}

/// Divides the given expression by the given symbol at position `symbolPos`. It
22 changes: 19 additions & 3 deletions mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
}

// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
// CHECK-LABEL: func @semiaffine_composite_floor
func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
// CHECK-LABEL: func @semiaffine_composite_ceildiv
func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
// CHECK: %[[CST:.*]] = arith.constant 43
return %a : index
}

// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
// CHECK-LABEL: func @semiaffine_composite_ceildiv_mul_ceildiv
func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
// CHECK: %[[CST:.*]] = arith.constant 47
// CHECK-NOT: arith.constant
return %a : index
}

// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
// CHECK-LABEL: func @semiaffine_composite_floordiv_mul_floordiv
func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
// CHECK-NOT: arith.constant
return %a : index
}
