Skip to content

Commit af8be8c

Browse files
author
chenlonglong
committed
[mlir][affine] fix the issue of celidiv mul childiv expression not satisfying commutative
Fixes #107508
1 parent 9e73159 commit af8be8c

File tree

3 files changed

+140
-57
lines changed

3 files changed

+140
-57
lines changed

llvm/include/llvm/ADT/ScopeExit.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@ template <typename Callable> class scope_exit {
3131
template <typename Fp>
3232
explicit scope_exit(Fp &&F) : ExitFunction(std::forward<Fp>(F)) {}
3333

34-
scope_exit(scope_exit &&Rhs)
35-
: ExitFunction(std::move(Rhs.ExitFunction)), Engaged(Rhs.Engaged) {
36-
Rhs.release();
37-
}
34+
scope_exit(scope_exit &&Rhs) { *this = std::move(Rhs); }
3835
scope_exit(const scope_exit &) = delete;
39-
scope_exit &operator=(scope_exit &&) = delete;
4036
scope_exit &operator=(const scope_exit &) = delete;
37+
scope_exit &operator=(scope_exit &&Rhs) {
38+
Engaged = std::exchange(Rhs.Engaged, false);
39+
ExitFunction = std::move(Rhs.ExitFunction);
40+
return *this;
41+
}
4142

4243
void release() { Engaged = false; }
4344

mlir/lib/IR/AffineExpr.cpp

Lines changed: 115 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include <cmath>
1010
#include <cstdint>
1111
#include <limits>
12+
#include <numeric>
13+
#include <optional>
1214
#include <utility>
1315

1416
#include "AffineExprDetail.h"
@@ -18,9 +20,8 @@
1820
#include "mlir/IR/IntegerSet.h"
1921
#include "mlir/Support/TypeID.h"
2022
#include "llvm/ADT/STLExtras.h"
23+
#include "llvm/ADT/ScopeExit.h"
2124
#include "llvm/Support/MathExtras.h"
22-
#include <numeric>
23-
#include <optional>
2425

2526
using namespace mlir;
2627
using namespace mlir::detail;
@@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
362363
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
363364
opKind == AffineExprKind::CeilDiv) &&
364365
"unexpected opKind");
365-
switch (expr.getKind()) {
366-
case AffineExprKind::Constant:
367-
return cast<AffineConstantExpr>(expr).getValue() == 0;
368-
case AffineExprKind::DimId:
369-
return false;
370-
case AffineExprKind::SymbolId:
371-
return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
372-
// Checks divisibility by the given symbol for both operands.
373-
case AffineExprKind::Add: {
374-
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
375-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
376-
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
377-
}
378-
// Checks divisibility by the given symbol for both operands. Consider the
379-
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
380-
// this is a division by s1 and both the operands of modulo are divisible by
381-
// s1 but it is not divisible by s1 always. The third argument is
382-
// `AffineExprKind::Mod` for this reason.
383-
case AffineExprKind::Mod: {
384-
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
385-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
386-
AffineExprKind::Mod) &&
387-
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
388-
AffineExprKind::Mod);
389-
}
390-
// Checks if any of the operand divisible by the given symbol.
391-
case AffineExprKind::Mul: {
392-
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
393-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
394-
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
395-
}
396-
// Floordiv and ceildiv are divisible by the given symbol when the first
397-
// operand is divisible, and the affine expression kind of the argument expr
398-
// is same as the argument `opKind`. This can be inferred from commutative
399-
// property of floordiv and ceildiv operations and are as follow:
400-
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
401-
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
402-
// It will fail if operations are not same. For example:
403-
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
404-
case AffineExprKind::FloorDiv:
405-
case AffineExprKind::CeilDiv: {
406-
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
407-
if (opKind != expr.getKind())
408-
return false;
409-
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
410-
}
366+
std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
367+
llvm::detail::scope_exit<std::function<void(void)>>>>
368+
stack;
369+
stack.emplace_back(expr, symbolPos, opKind, []() {});
370+
bool result = false;
371+
372+
while (!stack.empty()) {
373+
AffineExpr expr = std::get<0>(stack.back());
374+
unsigned symbolPos = std::get<1>(stack.back());
375+
AffineExprKind opKind = std::get<2>(stack.back());
376+
377+
switch (expr.getKind()) {
378+
case AffineExprKind::Constant: {
379+
// Note: Assignment must occur before pop, which will affect whether it
380+
// enters other execution branches.
381+
result = cast<AffineConstantExpr>(expr).getValue() == 0;
382+
stack.pop_back();
383+
break;
384+
}
385+
case AffineExprKind::DimId: {
386+
result = false;
387+
stack.pop_back();
388+
break;
389+
}
390+
case AffineExprKind::SymbolId: {
391+
result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
392+
stack.pop_back();
393+
break;
394+
}
395+
// Checks divisibility by the given symbol for both operands.
396+
case AffineExprKind::Add: {
397+
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
398+
stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
399+
[&stack, &result, binaryExpr, symbolPos, opKind]() {
400+
if (result) {
401+
stack.emplace_back(
402+
binaryExpr.getRHS(), symbolPos, opKind,
403+
[&stack]() { stack.pop_back(); });
404+
} else {
405+
stack.pop_back();
406+
}
407+
});
408+
break;
409+
}
410+
// Checks divisibility by the given symbol for both operands. Consider the
411+
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
412+
// s1`, this is a division by s1 and both the operands of modulo are
413+
// divisible by s1 but it is not divisible by s1 always. The third argument
414+
// is `AffineExprKind::Mod` for this reason.
415+
case AffineExprKind::Mod: {
416+
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
417+
stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
418+
[&stack, &result, binaryExpr, symbolPos, opKind]() {
419+
if (result) {
420+
stack.emplace_back(
421+
binaryExpr.getRHS(), symbolPos,
422+
AffineExprKind::Mod,
423+
[&stack]() { stack.pop_back(); });
424+
} else {
425+
stack.pop_back();
426+
}
427+
});
428+
break;
429+
}
430+
// Checks if any of the operand divisible by the given symbol.
431+
case AffineExprKind::Mul: {
432+
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
433+
stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
434+
[&stack, &result, binaryExpr, symbolPos, opKind]() {
435+
if (!result) {
436+
stack.emplace_back(
437+
binaryExpr.getRHS(), symbolPos, opKind,
438+
[&stack]() { stack.pop_back(); });
439+
} else {
440+
stack.pop_back();
441+
}
442+
});
443+
break;
444+
}
445+
// Floordiv and ceildiv are divisible by the given symbol when the first
446+
// operand is divisible, and the affine expression kind of the argument expr
447+
// is same as the argument `opKind`. This can be inferred from commutative
448+
// property of floordiv and ceildiv operations and are as follow:
449+
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
450+
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
451+
// It will fail 1.if operations are not same. For example:
452+
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
453+
// multiplication operation in the expression. For example:
454+
// (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
455+
case AffineExprKind::FloorDiv:
456+
case AffineExprKind::CeilDiv: {
457+
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
458+
if (opKind != expr.getKind()) {
459+
result = false;
460+
stack.pop_back();
461+
break;
462+
}
463+
if (llvm::any_of(stack, [](auto &it) {
464+
return std::get<0>(it).getKind() == AffineExprKind::Mul;
465+
})) {
466+
result = false;
467+
stack.pop_back();
468+
break;
469+
}
470+
471+
stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
472+
[&stack]() { stack.pop_back(); });
473+
break;
474+
}
475+
llvm_unreachable("Unknown AffineExpr");
476+
}
411477
}
412-
llvm_unreachable("Unknown AffineExpr");
478+
return result;
413479
}
414480

415481
/// Divides the given expression by the given symbol at position `symbolPos`. It

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
308308
}
309309

310310
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
311-
// CHECK-LABEL: func @semiaffine_composite_floor
312-
func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
311+
// CHECK-LABEL: func @semiaffine_composite_ceildiv
312+
func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
313+
%a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
314+
// CHECK: %[[CST:.*]] = arith.constant 43
315+
return %a : index
316+
}
317+
318+
// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation.
319+
// CHECK-LABEL: func @semiaffine_composite_ceildiv
320+
func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
313321
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
314-
// CHECK: %[[CST:.*]] = arith.constant 47
322+
// CHECK-NOT: arith.constant
323+
return %a : index
324+
}
325+
326+
// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation
327+
// CHECK-LABEL: func @semiaffine_composite_floordiv
328+
func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
329+
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
330+
// CHECK-NOT: arith.constant
315331
return %a : index
316332
}
317333

0 commit comments

Comments
 (0)