Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 147 additions & 49 deletions mlir/lib/IR/AffineExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <cmath>
#include <cstdint>
#include <limits>
#include <numeric>
#include <optional>
#include <utility>

#include "AffineExprDetail.h"
Expand All @@ -18,9 +20,8 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
#include <optional>

using namespace mlir;
using namespace mlir::detail;
Expand Down Expand Up @@ -362,54 +363,151 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
switch (expr.getKind()) {
case AffineExprKind::Constant:
return cast<AffineConstantExpr>(expr).getValue() == 0;
case AffineExprKind::DimId:
return false;
case AffineExprKind::SymbolId:
return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
}
// Checks divisibility by the given symbol for both operands. Consider the
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
// this is a division by s1 and both the operands of modulo are divisible by
// s1 but it is not divisible by s1 always. The third argument is
// `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
AffineExprKind::Mod) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
AffineExprKind::Mod);
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
}
// Floordiv and ceildiv are divisible by the given symbol when the first
// operand is divisible, and the affine expression kind of the argument expr
// is same as the argument `opKind`. This can be inferred from commutative
// property of floordiv and ceildiv operations and are as follow:
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
// It will fail if operations are not same. For example:
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind())
return false;
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
}
SmallVector<std::tuple<AffineExpr, unsigned, AffineExprKind,
llvm::detail::scope_exit<std::function<void(void)>>>>
stack;
stack.emplace_back(expr, symbolPos, opKind, []() {});
bool result = false;

while (!stack.empty()) {
AffineExpr expr = std::get<0>(stack.back());
unsigned symbolPos = std::get<1>(stack.back());
AffineExprKind opKind = std::get<2>(stack.back());

switch (expr.getKind()) {
case AffineExprKind::Constant: {
// Note: Assignment must occur before pop, which will affect whether it
// enters other execution branches.
result = cast<AffineConstantExpr>(expr).getValue() == 0;
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}
case AffineExprKind::DimId: {
result = false;
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}
case AffineExprKind::SymbolId: {
result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
stack.emplace_back(
binaryExpr.getLHS(), symbolPos, opKind,
[&stack, &result, binaryExpr, symbolPos, opKind]() {
if (result) {
stack.emplace_back(
binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
});
} else {
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
}
});
break;
}
// Checks divisibility by the given symbol for both operands. Consider the
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
// s1`, this is a division by s1 and both the operands of modulo are
// divisible by s1 but it is not divisible by s1 always. The third argument
// is `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
stack.emplace_back(
binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
[&stack, &result, binaryExpr, symbolPos]() {
if (result) {
stack.emplace_back(
binaryExpr.getRHS(), symbolPos, AffineExprKind::Mod,
[&stack]() {
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
});
} else {
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
}
});
break;
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
stack.emplace_back(
binaryExpr.getLHS(), symbolPos, opKind,
[&stack, &result, binaryExpr, symbolPos, opKind]() {
if (!result) {
stack.emplace_back(
binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
});
} else {
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
}
});
break;
}
// Floordiv and ceildiv are divisible by the given symbol when the first
// operand is divisible, and the affine expression kind of the argument expr
// is same as the argument `opKind`. This can be inferred from commutative
// property of floordiv and ceildiv operations and are as follow:
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
// It will fail 1.if operations are not same. For example:
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
// multiplication operation in the expression. For example:
// (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind()) {
result = false;
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}
if (llvm::any_of(stack, [](auto &it) {
return std::get<0>(it).getKind() == AffineExprKind::Mul;
})) {
result = false;
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}

stack.emplace_back(
binaryExpr.getLHS(), symbolPos, expr.getKind(), [&stack]() {
llvm::detail::scope_exit<std::function<void(void)>> sexit(
std::move(std::get<3>(stack.back())));
stack.pop_back();
});
break;
}
llvm_unreachable("Unknown AffineExpr");
}
}
llvm_unreachable("Unknown AffineExpr");
return result;
}

/// Divides the given expression by the given symbol at position `symbolPos`. It
Expand Down
22 changes: 19 additions & 3 deletions mlir/test/Dialect/Affine/simplify-structures.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
}

// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
// CHECK-LABEL: func @semiaffine_composite_floor
func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
// CHECK-LABEL: func @semiaffine_composite_ceildiv
func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
// CHECK: %[[CST:.*]] = arith.constant 43
return %a : index
}

// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation.
// CHECK-LABEL: func @semiaffine_composite_ceildiv
func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
// CHECK: %[[CST:.*]] = arith.constant 47
// CHECK-NOT: arith.constant
return %a : index
}

// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation
// CHECK-LABEL: func @semiaffine_composite_floordiv
func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
// CHECK-NOT: arith.constant
return %a : index
}

Expand Down
Loading