@@ -132,7 +132,6 @@ struct FuseForalls final : OpRewritePattern<scf::ForallOp> {
132132
133133private:
134134 int64_t flatWorkgroupSize;
135- int64_t subgroupSize;
136135};
137136
138137struct FuseTilableDestinationProducers final : OpRewritePattern<scf::ForallOp> {
@@ -174,6 +173,68 @@ struct FuseTilableDestinationProducers final : OpRewritePattern<scf::ForallOp> {
174173 }
175174};
176175
176+ struct FuseUnitLoopDestination final : OpRewritePattern<scf::ForallOp> {
177+ using OpRewritePattern::OpRewritePattern;
178+ LogicalResult matchAndRewrite (scf::ForallOp forallOp,
179+ PatternRewriter &rewriter) const override {
180+ std::optional<int64_t > maybeTripCount = getStaticForallTripCount (forallOp);
181+ if (!maybeTripCount || *maybeTripCount != 1 ) {
182+ return rewriter.notifyMatchFailure (forallOp,
183+ " not a unit trip count loop" );
184+ }
185+ DestinationStyleOpInterface dpsProducer;
186+ BlockArgument bodyArg;
187+ Value dpsResult;
188+ for (auto iterArg : forallOp.getRegionIterArgs ()) {
189+ dpsResult = forallOp.getTiedLoopInit (iterArg)->get ();
190+ bodyArg = iterArg;
191+ dpsProducer = dpsResult.getDefiningOp <DestinationStyleOpInterface>();
192+ if (dpsProducer) {
193+ break ;
194+ }
195+ }
196+ if (!dpsProducer || !dpsProducer->hasOneUse ()) {
197+ return rewriter.notifyMatchFailure (forallOp,
198+ " no single use DPS producer" );
199+ }
200+
201+ Operation *parallelInsert = nullptr ;
202+ for (auto user : bodyArg.getUsers ()) {
203+ if (isa<tensor::ParallelInsertSliceOp>(user)) {
204+ // This should be illegal but check anyway.
205+ if (parallelInsert) {
206+ return rewriter.notifyMatchFailure (forallOp, " multiple insert users" );
207+ }
208+ parallelInsert = user;
209+ }
210+ }
211+ if (!parallelInsert) {
212+ return rewriter.notifyMatchFailure (
213+ forallOp, " destination not used by a parallel insert" );
214+ }
215+
216+ rewriter.startOpModification (forallOp);
217+ // Move the producer into the body of the forall loop.
218+ rewriter.moveOpBefore (dpsProducer, forallOp.getBody (),
219+ forallOp.getBody ()->begin ());
220+
221+ // Replace all uses of the region iter arg with the moved dps op.
222+ rewriter.replaceAllUsesExcept (bodyArg, dpsResult, parallelInsert);
223+
224+ // Set the init operand of the forall op to the init operand of the
225+ // producer.
226+ int64_t dpsInitIndex = cast<OpResult>(dpsResult).getResultNumber ();
227+ forallOp->setOperand (forallOp.getTiedOpOperand (bodyArg)->getOperandNumber (),
228+ dpsProducer.getDpsInitOperand (dpsInitIndex)->get ());
229+
230+ // Finally replace the init operand of the moved producer with the region
231+ // iter arg.
232+ dpsProducer.setDpsInitOperand (dpsInitIndex, bodyArg);
233+ rewriter.finalizeOpModification (forallOp);
234+ return success ();
235+ }
236+ };
237+
177238struct FuseTilableSliceProducers final
178239 : OpRewritePattern<tensor::ExtractSliceOp> {
179240 using OpRewritePattern::OpRewritePattern;
@@ -290,6 +351,7 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
290351 {
291352 RewritePatternSet patterns (context);
292353 patterns.add <FuseTilableDestinationProducers>(context);
354+ patterns.add <FuseUnitLoopDestination>(context);
293355 patterns.add <FuseTilableForallConsumers>(context);
294356 tensor::populateFoldTensorEmptyPatterns (patterns);
295357 scf::ForallOp::getCanonicalizationPatterns (patterns, context);
0 commit comments