@@ -772,6 +772,70 @@ LoopNest mlir::scf::buildLoopNest(
772772 });
773773}
774774
775+ SmallVector<Value>
776+ mlir::scf::replaceAndCastForOpIterArg (RewriterBase &rewriter, scf::ForOp forOp,
777+ OpOperand &operand, Value replacement,
778+ const ValueTypeCastFnTy &castFn) {
779+ assert (operand.getOwner () == forOp);
780+ Type oldType = operand.get ().getType (), newType = replacement.getType ();
781+
782+ // 1. Create new iter operands, exactly 1 is replaced.
783+ assert (operand.getOperandNumber () >= forOp.getNumControlOperands () &&
784+ " expected an iter OpOperand" );
785+ assert (operand.get ().getType () != replacement.getType () &&
786+ " Expected a different type" );
787+ SmallVector<Value> newIterOperands;
788+ for (OpOperand &opOperand : forOp.getInitArgsMutable ()) {
789+ if (opOperand.getOperandNumber () == operand.getOperandNumber ()) {
790+ newIterOperands.push_back (replacement);
791+ continue ;
792+ }
793+ newIterOperands.push_back (opOperand.get ());
794+ }
795+
796+ // 2. Create the new forOp shell.
797+ scf::ForOp newForOp = rewriter.create <scf::ForOp>(
798+ forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
799+ forOp.getStep (), newIterOperands);
800+ newForOp->setAttrs (forOp->getAttrs ());
801+ Block &newBlock = newForOp.getRegion ().front ();
802+ SmallVector<Value, 4 > newBlockTransferArgs (newBlock.getArguments ().begin (),
803+ newBlock.getArguments ().end ());
804+
805+ // 3. Inject an incoming cast op at the beginning of the block for the bbArg
806+ // corresponding to the `replacement` value.
807+ OpBuilder::InsertionGuard g (rewriter);
808+ rewriter.setInsertionPoint (&newBlock, newBlock.begin ());
809+ BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg (
810+ &newForOp->getOpOperand (operand.getOperandNumber ()));
811+ Value castIn = castFn (rewriter, newForOp.getLoc (), oldType, newRegionIterArg);
812+ newBlockTransferArgs[newRegionIterArg.getArgNumber ()] = castIn;
813+
814+ // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
815+ Block &oldBlock = forOp.getRegion ().front ();
816+ rewriter.mergeBlocks (&oldBlock, &newBlock, newBlockTransferArgs);
817+
818+ // 5. Inject an outgoing cast op at the end of the block and yield it instead.
819+ auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator ());
820+ rewriter.setInsertionPoint (clonedYieldOp);
821+ unsigned yieldIdx =
822+ newRegionIterArg.getArgNumber () - forOp.getNumInductionVars ();
823+ Value castOut = castFn (rewriter, newForOp.getLoc (), newType,
824+ clonedYieldOp.getOperand (yieldIdx));
825+ SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands ();
826+ newYieldOperands[yieldIdx] = castOut;
827+ rewriter.create <scf::YieldOp>(newForOp.getLoc (), newYieldOperands);
828+ rewriter.eraseOp (clonedYieldOp);
829+
830+ // 6. Inject an outgoing cast op after the forOp.
831+ rewriter.setInsertionPointAfter (newForOp);
832+ SmallVector<Value> newResults = newForOp.getResults ();
833+ newResults[yieldIdx] =
834+ castFn (rewriter, newForOp.getLoc (), oldType, newResults[yieldIdx]);
835+
836+ return newResults;
837+ }
838+
775839namespace {
776840// Fold away ForOp iter arguments when:
777841// 1) The op yields the iter arguments.
@@ -973,76 +1037,6 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
9731037 }
9741038};
9751039
976- // / Perform a replacement of one iter OpOperand of an scf.for to the
977- // / `replacement` value which is expected to be the source of a tensor.cast.
978- // / tensor.cast ops are inserted inside the block to account for the type cast.
979- static SmallVector<Value>
980- replaceTensorCastForOpIterArg (PatternRewriter &rewriter, OpOperand &operand,
981- Value replacement) {
982- Type oldType = operand.get ().getType (), newType = replacement.getType ();
983- assert (llvm::isa<RankedTensorType>(oldType) &&
984- llvm::isa<RankedTensorType>(newType) &&
985- " expected ranked tensor types" );
986-
987- // 1. Create new iter operands, exactly 1 is replaced.
988- ForOp forOp = cast<ForOp>(operand.getOwner ());
989- assert (operand.getOperandNumber () >= forOp.getNumControlOperands () &&
990- " expected an iter OpOperand" );
991- assert (operand.get ().getType () != replacement.getType () &&
992- " Expected a different type" );
993- SmallVector<Value> newIterOperands;
994- for (OpOperand &opOperand : forOp.getInitArgsMutable ()) {
995- if (opOperand.getOperandNumber () == operand.getOperandNumber ()) {
996- newIterOperands.push_back (replacement);
997- continue ;
998- }
999- newIterOperands.push_back (opOperand.get ());
1000- }
1001-
1002- // 2. Create the new forOp shell.
1003- scf::ForOp newForOp = rewriter.create <scf::ForOp>(
1004- forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1005- forOp.getStep (), newIterOperands);
1006- newForOp->setAttrs (forOp->getAttrs ());
1007- Block &newBlock = newForOp.getRegion ().front ();
1008- SmallVector<Value, 4 > newBlockTransferArgs (newBlock.getArguments ().begin (),
1009- newBlock.getArguments ().end ());
1010-
1011- // 3. Inject an incoming cast op at the beginning of the block for the bbArg
1012- // corresponding to the `replacement` value.
1013- OpBuilder::InsertionGuard g (rewriter);
1014- rewriter.setInsertionPoint (&newBlock, newBlock.begin ());
1015- BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg (
1016- &newForOp->getOpOperand (operand.getOperandNumber ()));
1017- Value castIn = rewriter.create <tensor::CastOp>(newForOp.getLoc (), oldType,
1018- newRegionIterArg);
1019- newBlockTransferArgs[newRegionIterArg.getArgNumber ()] = castIn;
1020-
1021- // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
1022- Block &oldBlock = forOp.getRegion ().front ();
1023- rewriter.mergeBlocks (&oldBlock, &newBlock, newBlockTransferArgs);
1024-
1025- // 5. Inject an outgoing cast op at the end of the block and yield it instead.
1026- auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator ());
1027- rewriter.setInsertionPoint (clonedYieldOp);
1028- unsigned yieldIdx =
1029- newRegionIterArg.getArgNumber () - forOp.getNumInductionVars ();
1030- Value castOut = rewriter.create <tensor::CastOp>(
1031- newForOp.getLoc (), newType, clonedYieldOp.getOperand (yieldIdx));
1032- SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands ();
1033- newYieldOperands[yieldIdx] = castOut;
1034- rewriter.create <scf::YieldOp>(newForOp.getLoc (), newYieldOperands);
1035- rewriter.eraseOp (clonedYieldOp);
1036-
1037- // 6. Inject an outgoing cast op after the forOp.
1038- rewriter.setInsertionPointAfter (newForOp);
1039- SmallVector<Value> newResults = newForOp.getResults ();
1040- newResults[yieldIdx] = rewriter.create <tensor::CastOp>(
1041- newForOp.getLoc (), oldType, newResults[yieldIdx]);
1042-
1043- return newResults;
1044- }
1045-
10461040// / Fold scf.for iter_arg/result pairs that go through incoming/ougoing
10471041// / a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
10481042// /
@@ -1090,9 +1084,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
10901084 continue ;
10911085
10921086 // Create a new ForOp with that iter operand replaced.
1087+ ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type,
1088+ Value source) {
1089+ return b.create <tensor::CastOp>(loc, type, source);
1090+ };
10931091 rewriter.replaceOp (
1094- op, replaceTensorCastForOpIterArg (rewriter, iterOpOperand,
1095- incomingCast.getSource ()));
1092+ op, replaceAndCastForOpIterArg (rewriter, op , iterOpOperand,
1093+ incomingCast.getSource (), castFn ));
10961094 return success ();
10971095 }
10981096 return failure ();
0 commit comments