|
9 | 9 | #include <cmath> |
10 | 10 | #include <cstdint> |
11 | 11 | #include <limits> |
| 12 | +#include <numeric> |
| 13 | +#include <optional> |
12 | 14 | #include <utility> |
13 | 15 |
|
14 | 16 | #include "AffineExprDetail.h" |
|
18 | 20 | #include "mlir/IR/IntegerSet.h" |
19 | 21 | #include "mlir/Support/TypeID.h" |
20 | 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | +#include "llvm/ADT/ScopeExit.h" |
21 | 24 | #include "llvm/Support/MathExtras.h" |
22 | | -#include <numeric> |
23 | | -#include <optional> |
24 | 25 |
|
25 | 26 | using namespace mlir; |
26 | 27 | using namespace mlir::detail; |
@@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, |
362 | 363 | assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || |
363 | 364 | opKind == AffineExprKind::CeilDiv) && |
364 | 365 | "unexpected opKind"); |
365 | | - switch (expr.getKind()) { |
366 | | - case AffineExprKind::Constant: |
367 | | - return cast<AffineConstantExpr>(expr).getValue() == 0; |
368 | | - case AffineExprKind::DimId: |
369 | | - return false; |
370 | | - case AffineExprKind::SymbolId: |
371 | | - return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos); |
372 | | - // Checks divisibility by the given symbol for both operands. |
373 | | - case AffineExprKind::Add: { |
374 | | - AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); |
375 | | - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) && |
376 | | - isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); |
377 | | - } |
378 | | - // Checks divisibility by the given symbol for both operands. Consider the |
379 | | - // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`, |
380 | | - // this is a division by s1 and both the operands of modulo are divisible by |
381 | | - // s1 but it is not divisible by s1 always. The third argument is |
382 | | - // `AffineExprKind::Mod` for this reason. |
383 | | - case AffineExprKind::Mod: { |
384 | | - AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); |
385 | | - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, |
386 | | - AffineExprKind::Mod) && |
387 | | - isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, |
388 | | - AffineExprKind::Mod); |
389 | | - } |
390 | | - // Checks if any of the operand divisible by the given symbol. |
391 | | - case AffineExprKind::Mul: { |
392 | | - AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); |
393 | | - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) || |
394 | | - isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); |
395 | | - } |
396 | | - // Floordiv and ceildiv are divisible by the given symbol when the first |
397 | | - // operand is divisible, and the affine expression kind of the argument expr |
398 | | - // is same as the argument `opKind`. This can be inferred from commutative |
399 | | - // property of floordiv and ceildiv operations and are as follow: |
400 | | - // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 |
401 | | - // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 |
402 | | - // It will fail if operations are not same. For example: |
403 | | - // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. |
404 | | - case AffineExprKind::FloorDiv: |
405 | | - case AffineExprKind::CeilDiv: { |
406 | | - AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); |
407 | | - if (opKind != expr.getKind()) |
408 | | - return false; |
409 | | - return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()); |
410 | | - } |
| 366 | + std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind, |
| 367 | + llvm::detail::scope_exit<std::function<void(void)>>>> |
| 368 | + stack; |
| 369 | + stack.emplace_back(expr, symbolPos, opKind, []() {}); |
| 370 | + bool result = false; |
| 371 | + |
| 372 | + while (!stack.empty()) { |
| 373 | + AffineExpr expr = std::get<0>(stack.back()); |
| 374 | + unsigned symbolPos = std::get<1>(stack.back()); |
| 375 | + AffineExprKind opKind = std::get<2>(stack.back()); |
| 376 | + |
| 377 | + switch (expr.getKind()) { |
| 378 | + case AffineExprKind::Constant: { |
| 379 | + // Note: Assignment must occur before pop, which will affect whether it |
| 380 | + // enters other execution branches. |
| 381 | + result = cast<AffineConstantExpr>(expr).getValue() == 0; |
| 382 | + stack.pop_back(); |
| 383 | + break; |
| 384 | + } |
| 385 | + case AffineExprKind::DimId: { |
| 386 | + result = false; |
| 387 | + stack.pop_back(); |
| 388 | + break; |
| 389 | + } |
| 390 | + case AffineExprKind::SymbolId: { |
| 391 | + result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos; |
| 392 | + stack.pop_back(); |
| 393 | + break; |
| 394 | + } |
| 395 | + // Checks divisibility by the given symbol for both operands. |
| 396 | + case AffineExprKind::Add: { |
| 397 | + AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); |
| 398 | + stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind, |
| 399 | + [&stack, &result, binaryExpr, symbolPos, opKind]() { |
| 400 | + if (result) { |
| 401 | + stack.emplace_back( |
| 402 | + binaryExpr.getRHS(), symbolPos, opKind, |
| 403 | + [&stack]() { stack.pop_back(); }); |
| 404 | + } else { |
| 405 | + stack.pop_back(); |
| 406 | + } |
| 407 | + }); |
| 408 | + break; |
| 409 | + } |
| 410 | + // Checks divisibility by the given symbol for both operands. Consider the |
| 411 | + // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv |
| 412 | + // s1`, this is a division by s1 and both the operands of modulo are |
| 413 | + // divisible by s1 but it is not divisible by s1 always. The third argument |
| 414 | + // is `AffineExprKind::Mod` for this reason. |
| 415 | + case AffineExprKind::Mod: { |
| 416 | + AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); |
| 417 | + stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod, |
| 418 | + [&stack, &result, binaryExpr, symbolPos, opKind]() { |
| 419 | + if (result) { |
| 420 | + stack.emplace_back( |
| 421 | + binaryExpr.getRHS(), symbolPos, |
| 422 | + AffineExprKind::Mod, |
| 423 | + [&stack]() { stack.pop_back(); }); |
| 424 | + } else { |
| 425 | + stack.pop_back(); |
| 426 | + } |
| 427 | + }); |
| 428 | + break; |
| 429 | + } |
| 430 | + // Checks if any of the operand divisible by the given symbol. |
| 431 | + case AffineExprKind::Mul: { |
| 432 | + AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); |
| 433 | + stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind, |
| 434 | + [&stack, &result, binaryExpr, symbolPos, opKind]() { |
| 435 | + if (!result) { |
| 436 | + stack.emplace_back( |
| 437 | + binaryExpr.getRHS(), symbolPos, opKind, |
| 438 | + [&stack]() { stack.pop_back(); }); |
| 439 | + } else { |
| 440 | + stack.pop_back(); |
| 441 | + } |
| 442 | + }); |
| 443 | + break; |
| 444 | + } |
| 445 | + // Floordiv and ceildiv are divisible by the given symbol when the first |
| 446 | + // operand is divisible, and the affine expression kind of the argument expr |
| 447 | + // is same as the argument `opKind`. This can be inferred from commutative |
| 448 | + // property of floordiv and ceildiv operations and are as follow: |
| 449 | + // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 |
| 450 | + // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 |
| 451 | + // It will fail 1.if operations are not same. For example: |
| 452 | + // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a |
| 453 | + // multiplication operation in the expression. For example: |
| 454 | + // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified. |
| 455 | + case AffineExprKind::FloorDiv: |
| 456 | + case AffineExprKind::CeilDiv: { |
| 457 | + AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); |
| 458 | + if (opKind != expr.getKind()) { |
| 459 | + result = false; |
| 460 | + stack.pop_back(); |
| 461 | + break; |
| 462 | + } |
| 463 | + if (llvm::any_of(stack, [](auto &it) { |
| 464 | + return std::get<0>(it).getKind() == AffineExprKind::Mul; |
| 465 | + })) { |
| 466 | + result = false; |
| 467 | + stack.pop_back(); |
| 468 | + break; |
| 469 | + } |
| 470 | + |
| 471 | + stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(), |
| 472 | + [&stack]() { stack.pop_back(); }); |
| 473 | + break; |
| 474 | + } |
| 475 | + llvm_unreachable("Unknown AffineExpr"); |
| 476 | + } |
411 | 477 | } |
412 | | - llvm_unreachable("Unknown AffineExpr"); |
| 478 | + return result; |
413 | 479 | } |
414 | 480 |
|
415 | 481 | /// Divides the given expression by the given symbol at position `symbolPos`. It |
|
0 commit comments