99#include " mlir/Transforms/RegionUtils.h"
1010#include " mlir/Analysis/TopologicalSortUtils.h"
1111#include " mlir/IR/Block.h"
12- #include " mlir/IR/BuiltinOps.h"
1312#include " mlir/IR/IRMapping.h"
1413#include " mlir/IR/Operation.h"
1514#include " mlir/IR/PatternMatch.h"
1615#include " mlir/IR/RegionGraphTraits.h"
1716#include " mlir/IR/Value.h"
1817#include " mlir/Interfaces/ControlFlowInterfaces.h"
1918#include " mlir/Interfaces/SideEffectInterfaces.h"
20- #include " mlir/Support/LogicalResult.h"
2119
2220#include " llvm/ADT/DepthFirstIterator.h"
2321#include " llvm/ADT/PostOrderIterator.h"
24- #include " llvm/ADT/STLExtras.h"
25- #include " llvm/ADT/SmallSet.h"
2622
2723#include < deque>
28- #include < iterator>
2924
3025using namespace mlir ;
3126
@@ -679,91 +674,6 @@ static bool ableToUpdatePredOperands(Block *block) {
679674 return true ;
680675}
681676
682- // / Prunes the redundant list of arguments. E.g., if we are passing an argument
683- // / list like [x, y, z, x] this would return [x, y, z] and it would update the
684- // / `block` (to whom the argument are passed to) accordingly.
685- static SmallVector<SmallVector<Value, 8 >, 2 > pruneRedundantArguments (
686- const SmallVector<SmallVector<Value, 8 >, 2 > &newArguments,
687- RewriterBase &rewriter, Block *block) {
688-
689- SmallVector<SmallVector<Value, 8 >, 2 > newArgumentsPruned (
690- newArguments.size (), SmallVector<Value, 8 >());
691-
692- if (newArguments.empty ())
693- return newArguments;
694-
695- // `newArguments` is a 2D array of size `numLists` x `numArgs`
696- unsigned numLists = newArguments.size ();
697- unsigned numArgs = newArguments[0 ].size ();
698-
699- // Map that for each arg index contains the index that we can use in place of
700- // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
701- // idxToReplacement[3] = 0
702- llvm::DenseMap<unsigned , unsigned > idxToReplacement;
703-
704- // This is a useful data structure to track the first appearance of a Value
705- // on a given list of arguments
706- DenseMap<Value, unsigned > firstValueToIdx;
707- for (unsigned j = 0 ; j < numArgs; ++j) {
708- Value newArg = newArguments[0 ][j];
709- if (!firstValueToIdx.contains (newArg))
710- firstValueToIdx[newArg] = j;
711- }
712-
713- // Go through the first list of arguments (list 0).
714- for (unsigned j = 0 ; j < numArgs; ++j) {
715- bool shouldReplaceJ = false ;
716- unsigned replacement = 0 ;
717- // Look back to see if there are possible redundancies in list 0. Please
718- // note that we are using a map to annotate when an argument was seen first
719- // to avoid a O(N^2) algorithm. This has the drawback that if we have two
720- // lists like:
721- // list0: [%a, %a, %a]
722- // list1: [%c, %b, %b]
723- // We cannot simplify it, because firstVlaueToIdx[%a] = 0, but we cannot
724- // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
725- // the number of arguments can be potentially unbounded we cannot afford a
726- // O(N^2) algorithm (to search to all the possible pairs) and we need to
727- // accept the trade-off.
728- unsigned k = firstValueToIdx[newArguments[0 ][j]];
729- if (k != j) {
730- shouldReplaceJ = true ;
731- replacement = k;
732- // If a possible redundancy is found, then scan the other lists: we
733- // can prune the arguments if and only if they are redundant in every
734- // list.
735- for (unsigned i = 1 ; i < numLists; ++i)
736- shouldReplaceJ =
737- shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
738- }
739- // Save the replacement.
740- if (shouldReplaceJ)
741- idxToReplacement[j] = replacement;
742- }
743-
744- // Populate the pruned argument list.
745- for (unsigned i = 0 ; i < numLists; ++i)
746- for (unsigned j = 0 ; j < numArgs; ++j)
747- if (!idxToReplacement.contains (j))
748- newArgumentsPruned[i].push_back (newArguments[i][j]);
749-
750- // Replace the block's redundant arguments.
751- SmallVector<unsigned > toErase;
752- for (auto [idx, arg] : llvm::enumerate (block->getArguments ())) {
753- if (idxToReplacement.contains (idx)) {
754- Value oldArg = block->getArgument (idx);
755- Value newArg = block->getArgument (idxToReplacement[idx]);
756- rewriter.replaceAllUsesWith (oldArg, newArg);
757- toErase.push_back (idx);
758- }
759- }
760-
761- // Erase the block's redundant arguments.
762- for (unsigned idxToErase : llvm::reverse (toErase))
763- block->eraseArgument (idxToErase);
764- return newArgumentsPruned;
765- }
766-
767677LogicalResult BlockMergeCluster::merge (RewriterBase &rewriter) {
768678 // Don't consider clusters that don't have blocks to merge.
769679 if (blocksToMerge.empty ())
@@ -812,10 +722,6 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
812722 }
813723 }
814724 }
815-
816- // Prune redundant arguments and update the leader block argument list
817- newArguments = pruneRedundantArguments (newArguments, rewriter, leaderBlock);
818-
819725 // Update the predecessors for each of the blocks.
820726 auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
821727 for (auto predIt = block->pred_begin (), predE = block->pred_end ();
@@ -912,108 +818,6 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
912818 return success (anyChanged);
913819}
914820
915- static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
916- Block &block) {
917- SmallVector<size_t > argsToErase;
918-
919- // Go through the arguments of the block.
920- for (auto [argIdx, blockOperand] : llvm::enumerate (block.getArguments ())) {
921- bool sameArg = true ;
922- Value commonValue;
923-
924- // Go through the block predecessor and flag if they pass to the block
925- // different values for the same argument.
926- for (auto predIt = block.pred_begin (), predE = block.pred_end ();
927- predIt != predE; ++predIt) {
928- auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator ());
929- if (!branch) {
930- sameArg = false ;
931- break ;
932- }
933- unsigned succIndex = predIt.getSuccessorIndex ();
934- SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
935- auto branchOperands = succOperands.getForwardedOperands ();
936- if (!commonValue) {
937- commonValue = branchOperands[argIdx];
938- } else {
939- if (branchOperands[argIdx] != commonValue) {
940- sameArg = false ;
941- break ;
942- }
943- }
944- }
945-
946- // If they are passing the same value, drop the argument.
947- if (commonValue && sameArg) {
948- argsToErase.push_back (argIdx);
949-
950- // Remove the argument from the block.
951- rewriter.replaceAllUsesWith (blockOperand, commonValue);
952- }
953- }
954-
955- // Remove the arguments.
956- for (auto argIdx : llvm::reverse (argsToErase)) {
957- block.eraseArgument (argIdx);
958-
959- // Remove the argument from the branch ops.
960- for (auto predIt = block.pred_begin (), predE = block.pred_end ();
961- predIt != predE; ++predIt) {
962- auto branch = cast<BranchOpInterface>((*predIt)->getTerminator ());
963- unsigned succIndex = predIt.getSuccessorIndex ();
964- SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
965- succOperands.erase (argIdx);
966- }
967- }
968- return success (!argsToErase.empty ());
969- }
970-
971- // / This optimization drops redundant argument to blocks. I.e., if a given
972- // / argument to a block receives the same value from each of the block
973- // / predecessors, we can remove the argument from the block and use directly the
974- // / original value. This is a simple example:
975- // /
976- // / %cond = llvm.call @rand() : () -> i1
977- // / %val0 = llvm.mlir.constant(1 : i64) : i64
978- // / %val1 = llvm.mlir.constant(2 : i64) : i64
979- // / %val2 = llvm.mlir.constant(3 : i64) : i64
980- // / llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
981- // / : i64)
982- // /
983- // / ^bb1(%arg0 : i64, %arg1 : i64):
984- // / llvm.call @foo(%arg0, %arg1)
985- // /
986- // / The previous IR can be rewritten as:
987- // / %cond = llvm.call @rand() : () -> i1
988- // / %val0 = llvm.mlir.constant(1 : i64) : i64
989- // / %val1 = llvm.mlir.constant(2 : i64) : i64
990- // / %val2 = llvm.mlir.constant(3 : i64) : i64
991- // / llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
992- // /
993- // / ^bb1(%arg0 : i64):
994- // / llvm.call @foo(%val0, %arg0)
995- // /
996- static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
997- MutableArrayRef<Region> regions) {
998- llvm::SmallSetVector<Region *, 1 > worklist;
999- for (Region ®ion : regions)
1000- worklist.insert (®ion);
1001- bool anyChanged = false ;
1002- while (!worklist.empty ()) {
1003- Region *region = worklist.pop_back_val ();
1004-
1005- // Add any nested regions to the worklist.
1006- for (Block &block : *region) {
1007- anyChanged = succeeded (dropRedundantArguments (rewriter, block));
1008-
1009- for (Operation &op : block)
1010- for (Region &nestedRegion : op.getRegions ())
1011- worklist.insert (&nestedRegion);
1012- }
1013- }
1014- return success (anyChanged);
1015- }
1016-
1017821// ===----------------------------------------------------------------------===//
1018822// Region Simplification
1019823// ===----------------------------------------------------------------------===//
@@ -1028,12 +832,8 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
1028832 bool eliminatedBlocks = succeeded (eraseUnreachableBlocks (rewriter, regions));
1029833 bool eliminatedOpsOrArgs = succeeded (runRegionDCE (rewriter, regions));
1030834 bool mergedIdenticalBlocks = false ;
1031- bool droppedRedundantArguments = false ;
1032- if (mergeBlocks) {
835+ if (mergeBlocks)
1033836 mergedIdenticalBlocks = succeeded (mergeIdenticalBlocks (rewriter, regions));
1034- droppedRedundantArguments =
1035- succeeded (dropRedundantArguments (rewriter, regions));
1036- }
1037837 return success (eliminatedBlocks || eliminatedOpsOrArgs ||
1038- mergedIdenticalBlocks || droppedRedundantArguments );
838+ mergedIdenticalBlocks);
1039839}
0 commit comments