@@ -2665,38 +2665,125 @@ struct AffineIfSinking : public OpRewritePattern<AffineIfOp> {
2665
2665
auto par = dyn_cast<AffineParallelOp>(op->getParentOp ());
2666
2666
if (!par)
2667
2667
return failure ();
2668
+ if (!isa<AffineYieldOp>(op->getNextNode ()))
2669
+ return failure ();
2670
+
2671
+ bool failed = false ;
2672
+ op->walk ([&](Operation *sub) {
2673
+ if (sub != op) {
2674
+ for (auto oper : sub->getOperands ()) {
2675
+ if (par.getRegion ().isAncestor (((Value)oper).getParentRegion ()) &&
2676
+ !op.thenRegion ().isAncestor (((Value)oper).getParentRegion ())) {
2677
+ failed = true ;
2678
+ return ;
2679
+ }
2680
+ }
2681
+ }
2682
+ });
2683
+ if (failed)
2684
+ return failure ();
2685
+
2668
2686
if (par.getSteps ().size () != op.getIntegerSet ().getConstraints ().size ())
2669
2687
return failure ();
2688
+
2670
2689
for (auto cst : llvm::enumerate (op.getIntegerSet ().getConstraints ())) {
2671
2690
auto opd = cst.value ().dyn_cast <AffineDimExpr>();
2672
- if (!opd)
2691
+ if (!opd) {
2692
+ opd = (-cst.value ()).dyn_cast <AffineDimExpr>();
2693
+ }
2694
+ if (!opd) {
2673
2695
return failure ();
2674
- if (op.getOperands ()[opd.getPosition ()] != par.getIVs ()[cst.index ()])
2696
+ }
2697
+ if (op.getOperands ()[opd.getPosition ()] != par.getIVs ()[cst.index ()]) {
2675
2698
return failure ();
2676
- if (!op.getIntegerSet ().isEq (cst.index ()))
2699
+ }
2700
+ if (!op.getIntegerSet ().isEq (cst.index ())) {
2677
2701
return failure ();
2702
+ }
2678
2703
2679
2704
for (auto lb : par.getLowerBoundMap (cst.index ()).getResults ()) {
2680
2705
auto opd = lb.dyn_cast <AffineConstantExpr>();
2681
- if (!opd)
2706
+ if (!opd) {
2682
2707
return failure ();
2683
- if (opd.getValue () > 0 )
2708
+ }
2709
+ if (opd.getValue () > 0 ) {
2684
2710
return failure ();
2711
+ }
2685
2712
}
2686
2713
for (auto ub : par.getUpperBoundMap (cst.index ()).getResults ()) {
2687
2714
auto opd = ub.dyn_cast <AffineConstantExpr>();
2688
- if (!opd)
2715
+ if (!opd) {
2689
2716
return failure ();
2690
- if (opd.getValue () <= 0 )
2717
+ }
2718
+ if (opd.getValue () <= 0 ) {
2691
2719
return failure ();
2720
+ }
2721
+ }
2722
+ }
2723
+
2724
+ rewriter.eraseOp (op.getThenBlock ()->getTerminator ());
2725
+ rewriter.mergeBlockBefore (op.getThenBlock (), par->getNextNode ());
2726
+ rewriter.eraseOp (op);
2727
+ return success ();
2728
+ }
2729
+ };
2730
+
2731
+ static void replaceOpWithRegion (PatternRewriter &rewriter, Operation *op,
2732
+ Region ®ion, ValueRange blockArgs = {}) {
2733
+ assert (llvm::hasSingleElement (region) && " expected single-region block" );
2734
+ Block *block = ®ion.front ();
2735
+ Operation *terminator = block->getTerminator ();
2736
+ ValueRange results = terminator->getOperands ();
2737
+ rewriter.mergeBlockBefore (block, op, blockArgs);
2738
+ rewriter.replaceOp (op, results);
2739
+ rewriter.eraseOp (terminator);
2740
+ }
2741
+
2742
+ struct AffineIfSimplification : public OpRewritePattern <AffineIfOp> {
2743
+ using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2744
+
2745
+ LogicalResult matchAndRewrite (AffineIfOp op,
2746
+ PatternRewriter &rewriter) const override {
2747
+ SmallVector<AffineExpr> todo;
2748
+ bool knownFalse = false ;
2749
+ bool removed = false ;
2750
+ for (auto cst : llvm::enumerate (op.getIntegerSet ().getConstraints ())) {
2751
+ auto opd = cst.value ().dyn_cast <AffineConstantExpr>();
2752
+ if (!opd) {
2753
+ todo.push_back (cst.value ());
2754
+ continue ;
2755
+ }
2756
+ removed = true ;
2757
+
2758
+ if (op.getIntegerSet ().isEq (cst.index ())) {
2759
+ if (opd.getValue () != 0 ) {
2760
+ knownFalse = true ;
2761
+ break ;
2762
+ }
2763
+ }
2764
+ if (!(opd.getValue () >= 0 )) {
2765
+ knownFalse = true ;
2766
+ break ;
2692
2767
}
2693
2768
}
2694
- if (isa<AffineYieldOp>(op->getNextNode ())) {
2695
- rewriter.eraseOp (op.getThenBlock ()->getTerminator ());
2696
- rewriter.mergeBlockBefore (op.getThenBlock (), par->getNextNode ());
2697
- rewriter.eraseOp (op);
2769
+
2770
+ if (knownFalse) {
2771
+ todo.clear ();
2772
+ }
2773
+
2774
+ if (todo.size () == 0 ) {
2775
+
2776
+ if (!knownFalse)
2777
+ replaceOpWithRegion (rewriter, op, op.thenRegion ());
2778
+ else if (!op.elseRegion ().empty ())
2779
+ replaceOpWithRegion (rewriter, op, op.elseRegion ());
2780
+ else
2781
+ rewriter.eraseOp (op);
2782
+
2698
2783
return success ();
2699
2784
}
2785
+ // TODO can reduce the number of conditions, even if cannot eliminate
2786
+ // entirely.
2700
2787
return failure ();
2701
2788
}
2702
2789
};
@@ -2709,7 +2796,7 @@ void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results,
2709
2796
AlwaysAllocaScopeHoister<memref::AllocaScopeOp>,
2710
2797
AlwaysAllocaScopeHoister<scf::ForOp>,
2711
2798
AlwaysAllocaScopeHoister<AffineForOp>, ConstantRankReduction,
2712
- AffineIfSinking,
2799
+ AffineIfSinking, AffineIfSimplification,
2713
2800
// RankReduction<memref::AllocaOp, scf::ParallelOp>,
2714
2801
AggressiveAllocaScopeInliner, InductiveVarRemoval>(context);
2715
2802
}
0 commit comments