@@ -772,6 +772,70 @@ LoopNest mlir::scf::buildLoopNest(
772
772
});
773
773
}
774
774
775
+ SmallVector<Value>
776
+ mlir::scf::replaceAndCastForOpIterArg (RewriterBase &rewriter, scf::ForOp forOp,
777
+ OpOperand &operand, Value replacement,
778
+ const ValueTypeCastFnTy &castFn) {
779
+ assert (operand.getOwner () == forOp);
780
+ Type oldType = operand.get ().getType (), newType = replacement.getType ();
781
+
782
+ // 1. Create new iter operands, exactly 1 is replaced.
783
+ assert (operand.getOperandNumber () >= forOp.getNumControlOperands () &&
784
+ " expected an iter OpOperand" );
785
+ assert (operand.get ().getType () != replacement.getType () &&
786
+ " Expected a different type" );
787
+ SmallVector<Value> newIterOperands;
788
+ for (OpOperand &opOperand : forOp.getInitArgsMutable ()) {
789
+ if (opOperand.getOperandNumber () == operand.getOperandNumber ()) {
790
+ newIterOperands.push_back (replacement);
791
+ continue ;
792
+ }
793
+ newIterOperands.push_back (opOperand.get ());
794
+ }
795
+
796
+ // 2. Create the new forOp shell.
797
+ scf::ForOp newForOp = rewriter.create <scf::ForOp>(
798
+ forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
799
+ forOp.getStep (), newIterOperands);
800
+ newForOp->setAttrs (forOp->getAttrs ());
801
+ Block &newBlock = newForOp.getRegion ().front ();
802
+ SmallVector<Value, 4 > newBlockTransferArgs (newBlock.getArguments ().begin (),
803
+ newBlock.getArguments ().end ());
804
+
805
+ // 3. Inject an incoming cast op at the beginning of the block for the bbArg
806
+ // corresponding to the `replacement` value.
807
+ OpBuilder::InsertionGuard g (rewriter);
808
+ rewriter.setInsertionPoint (&newBlock, newBlock.begin ());
809
+ BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg (
810
+ &newForOp->getOpOperand (operand.getOperandNumber ()));
811
+ Value castIn = castFn (rewriter, newForOp.getLoc (), oldType, newRegionIterArg);
812
+ newBlockTransferArgs[newRegionIterArg.getArgNumber ()] = castIn;
813
+
814
+ // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
815
+ Block &oldBlock = forOp.getRegion ().front ();
816
+ rewriter.mergeBlocks (&oldBlock, &newBlock, newBlockTransferArgs);
817
+
818
+ // 5. Inject an outgoing cast op at the end of the block and yield it instead.
819
+ auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator ());
820
+ rewriter.setInsertionPoint (clonedYieldOp);
821
+ unsigned yieldIdx =
822
+ newRegionIterArg.getArgNumber () - forOp.getNumInductionVars ();
823
+ Value castOut = castFn (rewriter, newForOp.getLoc (), newType,
824
+ clonedYieldOp.getOperand (yieldIdx));
825
+ SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands ();
826
+ newYieldOperands[yieldIdx] = castOut;
827
+ rewriter.create <scf::YieldOp>(newForOp.getLoc (), newYieldOperands);
828
+ rewriter.eraseOp (clonedYieldOp);
829
+
830
+ // 6. Inject an outgoing cast op after the forOp.
831
+ rewriter.setInsertionPointAfter (newForOp);
832
+ SmallVector<Value> newResults = newForOp.getResults ();
833
+ newResults[yieldIdx] =
834
+ castFn (rewriter, newForOp.getLoc (), oldType, newResults[yieldIdx]);
835
+
836
+ return newResults;
837
+ }
838
+
775
839
namespace {
776
840
// Fold away ForOp iter arguments when:
777
841
// 1) The op yields the iter arguments.
@@ -973,76 +1037,6 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
973
1037
}
974
1038
};
975
1039
976
- // / Perform a replacement of one iter OpOperand of an scf.for to the
977
- // / `replacement` value which is expected to be the source of a tensor.cast.
978
- // / tensor.cast ops are inserted inside the block to account for the type cast.
979
- static SmallVector<Value>
980
- replaceTensorCastForOpIterArg (PatternRewriter &rewriter, OpOperand &operand,
981
- Value replacement) {
982
- Type oldType = operand.get ().getType (), newType = replacement.getType ();
983
- assert (llvm::isa<RankedTensorType>(oldType) &&
984
- llvm::isa<RankedTensorType>(newType) &&
985
- " expected ranked tensor types" );
986
-
987
- // 1. Create new iter operands, exactly 1 is replaced.
988
- ForOp forOp = cast<ForOp>(operand.getOwner ());
989
- assert (operand.getOperandNumber () >= forOp.getNumControlOperands () &&
990
- " expected an iter OpOperand" );
991
- assert (operand.get ().getType () != replacement.getType () &&
992
- " Expected a different type" );
993
- SmallVector<Value> newIterOperands;
994
- for (OpOperand &opOperand : forOp.getInitArgsMutable ()) {
995
- if (opOperand.getOperandNumber () == operand.getOperandNumber ()) {
996
- newIterOperands.push_back (replacement);
997
- continue ;
998
- }
999
- newIterOperands.push_back (opOperand.get ());
1000
- }
1001
-
1002
- // 2. Create the new forOp shell.
1003
- scf::ForOp newForOp = rewriter.create <scf::ForOp>(
1004
- forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1005
- forOp.getStep (), newIterOperands);
1006
- newForOp->setAttrs (forOp->getAttrs ());
1007
- Block &newBlock = newForOp.getRegion ().front ();
1008
- SmallVector<Value, 4 > newBlockTransferArgs (newBlock.getArguments ().begin (),
1009
- newBlock.getArguments ().end ());
1010
-
1011
- // 3. Inject an incoming cast op at the beginning of the block for the bbArg
1012
- // corresponding to the `replacement` value.
1013
- OpBuilder::InsertionGuard g (rewriter);
1014
- rewriter.setInsertionPoint (&newBlock, newBlock.begin ());
1015
- BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg (
1016
- &newForOp->getOpOperand (operand.getOperandNumber ()));
1017
- Value castIn = rewriter.create <tensor::CastOp>(newForOp.getLoc (), oldType,
1018
- newRegionIterArg);
1019
- newBlockTransferArgs[newRegionIterArg.getArgNumber ()] = castIn;
1020
-
1021
- // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
1022
- Block &oldBlock = forOp.getRegion ().front ();
1023
- rewriter.mergeBlocks (&oldBlock, &newBlock, newBlockTransferArgs);
1024
-
1025
- // 5. Inject an outgoing cast op at the end of the block and yield it instead.
1026
- auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator ());
1027
- rewriter.setInsertionPoint (clonedYieldOp);
1028
- unsigned yieldIdx =
1029
- newRegionIterArg.getArgNumber () - forOp.getNumInductionVars ();
1030
- Value castOut = rewriter.create <tensor::CastOp>(
1031
- newForOp.getLoc (), newType, clonedYieldOp.getOperand (yieldIdx));
1032
- SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands ();
1033
- newYieldOperands[yieldIdx] = castOut;
1034
- rewriter.create <scf::YieldOp>(newForOp.getLoc (), newYieldOperands);
1035
- rewriter.eraseOp (clonedYieldOp);
1036
-
1037
- // 6. Inject an outgoing cast op after the forOp.
1038
- rewriter.setInsertionPointAfter (newForOp);
1039
- SmallVector<Value> newResults = newForOp.getResults ();
1040
- newResults[yieldIdx] = rewriter.create <tensor::CastOp>(
1041
- newForOp.getLoc (), oldType, newResults[yieldIdx]);
1042
-
1043
- return newResults;
1044
- }
1045
-
1046
1040
/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
///
@@ -1090,9 +1084,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1090
1084
continue ;
1091
1085
1092
1086
// Create a new ForOp with that iter operand replaced.
1087
+ ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type,
1088
+ Value source) {
1089
+ return b.create <tensor::CastOp>(loc, type, source);
1090
+ };
1093
1091
rewriter.replaceOp (
1094
- op, replaceTensorCastForOpIterArg (rewriter, iterOpOperand,
1095
- incomingCast.getSource ()));
1092
+ op, replaceAndCastForOpIterArg (rewriter, op , iterOpOperand,
1093
+ incomingCast.getSource (), castFn ));
1096
1094
return success ();
1097
1095
}
1098
1096
return failure ();
0 commit comments