|
26 | 26 | #include "mlir/Interfaces/ParallelCombiningOpInterface.h" |
27 | 27 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
28 | 28 | #include "mlir/Transforms/InliningUtils.h" |
| 29 | +#include "mlir/Transforms/RegionUtils.h" |
29 | 30 | #include "llvm/ADT/MapVector.h" |
30 | 31 | #include "llvm/ADT/STLExtras.h" |
31 | 32 | #include "llvm/ADT/SmallPtrSet.h" |
@@ -3687,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() { |
3687 | 3688 | } |
3688 | 3689 |
|
3689 | 3690 | namespace { |
| 3691 | +/// Move a scf.if op that is directly before the scf.condition op in the while |
| 3692 | +/// before region, and whose condition matches the condition of the |
| 3693 | +/// scf.condition op, down into the while after region. |
| 3694 | +/// |
| 3695 | +/// scf.while (..) : (...) -> ... { |
| 3696 | +/// %additional_used_values = ... |
| 3697 | +/// %cond = ... |
| 3698 | +/// ... |
| 3699 | +/// %res = scf.if %cond -> (...) { |
| 3700 | +/// use(%additional_used_values) |
| 3701 | +/// ... // then block |
| 3702 | +/// scf.yield %then_value |
| 3703 | +/// } else { |
| 3704 | +/// scf.yield %else_value |
| 3705 | +/// } |
| 3706 | +/// scf.condition(%cond) %res, ... |
| 3707 | +/// } do { |
| 3708 | +/// ^bb0(%res_arg, ...): |
| 3709 | +/// use(%res_arg) |
| 3710 | +/// ... |
| 3711 | +/// |
| 3712 | +/// becomes |
| 3713 | +/// scf.while (..) : (...) -> ... { |
| 3714 | +/// %additional_used_values = ... |
| 3715 | +/// %cond = ... |
| 3716 | +/// ... |
| 3717 | +/// scf.condition(%cond) %else_value, ..., %additional_used_values |
| 3718 | +/// } do { |
| 3719 | +/// ^bb0(%res_arg ..., %additional_args): : |
| 3720 | +/// use(%additional_args) |
| 3721 | +/// ... // if then block |
| 3722 | +/// use(%then_value) |
| 3723 | +/// ... |
| 3724 | +struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> { |
| 3725 | + using OpRewritePattern<scf::WhileOp>::OpRewritePattern; |
| 3726 | + |
| 3727 | + LogicalResult matchAndRewrite(scf::WhileOp op, |
| 3728 | + PatternRewriter &rewriter) const override { |
| 3729 | + auto conditionOp = op.getConditionOp(); |
| 3730 | + |
| 3731 | + // Only support ifOp right before the condition at the moment. Relaxing this |
| 3732 | + // would require to: |
| 3733 | + // - check that the body does not have side-effects conflicting with |
| 3734 | + // operations between the if and the condition. |
| 3735 | + // - check that results of the if operation are only used as arguments to |
| 3736 | + // the condition. |
| 3737 | + auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode()); |
| 3738 | + |
| 3739 | + // Check that the ifOp is directly before the conditionOp and that it |
| 3740 | + // matches the condition of the conditionOp. Also ensure that the ifOp has |
| 3741 | + // no else block with content, as that would complicate the transformation. |
| 3742 | + // TODO: support else blocks with content. |
| 3743 | + if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() || |
| 3744 | + (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty())) |
| 3745 | + return failure(); |
| 3746 | + |
| 3747 | + assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) && |
| 3748 | + *ifOp->user_begin() == conditionOp) && |
| 3749 | + "ifOp has unexpected uses"); |
| 3750 | + |
| 3751 | + Location loc = op.getLoc(); |
| 3752 | + |
| 3753 | + // Replace uses of ifOp results in the conditionOp with the yielded values |
| 3754 | + // from the ifOp branches. |
| 3755 | + for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) { |
| 3756 | + auto it = llvm::find(ifOp->getResults(), arg); |
| 3757 | + if (it != ifOp->getResults().end()) { |
| 3758 | + size_t ifOpIdx = it.getIndex(); |
| 3759 | + Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx); |
| 3760 | + Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx); |
| 3761 | + |
| 3762 | + rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue); |
| 3763 | + rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue); |
| 3764 | + } |
| 3765 | + } |
| 3766 | + |
| 3767 | + // Collect additional used values from before region. |
| 3768 | + SetVector<Value> additionalUsedValuesSet; |
| 3769 | + visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { |
| 3770 | + if (&op.getBefore() == operand->get().getParentRegion()) |
| 3771 | + additionalUsedValuesSet.insert(operand->get()); |
| 3772 | + }); |
| 3773 | + |
| 3774 | + // Create new whileOp with additional used values as results. |
| 3775 | + auto additionalUsedValues = additionalUsedValuesSet.getArrayRef(); |
| 3776 | + auto additionalValueTypes = llvm::map_to_vector( |
| 3777 | + additionalUsedValues, [](Value val) { return val.getType(); }); |
| 3778 | + size_t additionalValueSize = additionalUsedValues.size(); |
| 3779 | + SmallVector<Type> newResultTypes(op.getResultTypes()); |
| 3780 | + newResultTypes.append(additionalValueTypes); |
| 3781 | + |
| 3782 | + auto newWhileOp = |
| 3783 | + scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits()); |
| 3784 | + |
| 3785 | + rewriter.modifyOpInPlace(newWhileOp, [&] { |
| 3786 | + newWhileOp.getBefore().takeBody(op.getBefore()); |
| 3787 | + newWhileOp.getAfter().takeBody(op.getAfter()); |
| 3788 | + newWhileOp.getAfter().addArguments( |
| 3789 | + additionalValueTypes, |
| 3790 | + SmallVector<Location>(additionalValueSize, loc)); |
| 3791 | + }); |
| 3792 | + |
| 3793 | + rewriter.modifyOpInPlace(conditionOp, [&] { |
| 3794 | + conditionOp.getArgsMutable().append(additionalUsedValues); |
| 3795 | + }); |
| 3796 | + |
| 3797 | + // Replace uses of additional used values inside the ifOp then region with |
| 3798 | + // the whileOp after region arguments. |
| 3799 | + rewriter.replaceUsesWithIf( |
| 3800 | + additionalUsedValues, |
| 3801 | + newWhileOp.getAfterArguments().take_back(additionalValueSize), |
| 3802 | + [&](OpOperand &use) { |
| 3803 | + return ifOp.getThenRegion().isAncestor( |
| 3804 | + use.getOwner()->getParentRegion()); |
| 3805 | + }); |
| 3806 | + |
| 3807 | + // Inline ifOp then region into new whileOp after region. |
| 3808 | + rewriter.eraseOp(ifOp.thenYield()); |
| 3809 | + rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(), |
| 3810 | + newWhileOp.getAfterBody()->begin()); |
| 3811 | + rewriter.eraseOp(ifOp); |
| 3812 | + rewriter.replaceOp(op, |
| 3813 | + newWhileOp->getResults().drop_back(additionalValueSize)); |
| 3814 | + return success(); |
| 3815 | + } |
| 3816 | +}; |
| 3817 | + |
3690 | 3818 | /// Replace uses of the condition within the do block with true, since otherwise |
3691 | 3819 | /// the block would not be evaluated. |
3692 | 3820 | /// |
@@ -4399,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, |
4399 | 4527 | results.add<RemoveLoopInvariantArgsFromBeforeBlock, |
4400 | 4528 | RemoveLoopInvariantValueYielded, WhileConditionTruth, |
4401 | 4529 | WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults, |
4402 | | - WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context); |
| 4530 | + WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>( |
| 4531 | + context); |
4403 | 4532 | } |
4404 | 4533 |
|
4405 | 4534 | //===----------------------------------------------------------------------===// |
|
0 commit comments