@@ -292,9 +292,10 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
292292 }
293293};
294294
295- // Pattern to eliminate ExecuteRegionOp results when it only forwards external
296- // values. It operates only on execute regions with single terminator yield
297- // operation.
295+ // Pattern to eliminate ExecuteRegionOp results which forward external
296+ // values from the region. In case there are multiple yield operations,
297+ // all of them must have the same operands in order for the pattern to be
298+ // applicable.
298299struct ExecuteRegionForwardingEliminator
299300 : public OpRewritePattern<ExecuteRegionOp> {
300301 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
@@ -306,11 +307,8 @@ struct ExecuteRegionForwardingEliminator
306307
307308 SmallVector<Operation *> yieldOps;
308309 for (Block &block : op.getRegion ()) {
309- if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator ())) {
310- if (yield.getOperands ().empty ())
311- continue ;
310+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator ()))
312311 yieldOps.push_back (yield.getOperation ());
313- }
314312 }
315313
316314 if (yieldOps.empty ())
@@ -323,36 +321,52 @@ struct ExecuteRegionForwardingEliminator
323321 return failure ();
324322 }
325323
326- // Check if all yielded values are from outside the region
327- bool allExternal = true ;
328- for (Value yieldedValue : yieldOpsOperands) {
324+ SmallVector<Value> externalValues;
325+ SmallVector<Value> internalValues;
326+ SmallVector<Value> opResultsToReplaceWithExternalValues;
327+ SmallVector<Value> opResultsToKeep;
328+ for (auto [index, yieldedValue] : llvm::enumerate (yieldOpsOperands)) {
329329 if (isValueFromInsideRegion (yieldedValue, op)) {
330- allExternal = false ;
331- break ;
330+ internalValues.push_back (yieldedValue);
331+ opResultsToKeep.push_back (op.getResult (index));
332+ } else {
333+ externalValues.push_back (yieldedValue);
334+ opResultsToReplaceWithExternalValues.push_back (op.getResult (index));
332335 }
333336 }
334-
335- if (!allExternal )
337+ // No yielded external values - nothing to do.
338+ if (externalValues. empty () )
336339 return failure ();
337340
338- // All yielded values are external - create a new execute_region with no
339- // results.
340- auto newOp = rewriter.create <ExecuteRegionOp>(op.getLoc (), TypeRange{});
341+ // There are yielded external values - create a new execute_region returning
342+ // just the internal values.
343+ SmallVector<Type> resultTypes;
344+ for (Value value : internalValues)
345+ resultTypes.push_back (value.getType ());
346+ auto newOp =
347+ rewriter.create <ExecuteRegionOp>(op.getLoc (), TypeRange (resultTypes));
341348 newOp->setAttrs (op->getAttrs ());
342349
343- // Move the region content to the new operation
344- newOp.getRegion ().takeBody (op.getRegion ());
350+ // Move old op's region to the new operation.
351+ rewriter.inlineRegionBefore (op.getRegion (), newOp.getRegion (),
352+ newOp.getRegion ().end ());
345353
346- // Replace all yield operations with a new yield operation with no results.
347- // scf.execute_region must have at least one yield operation.
354+ // Replace all yield operations with a new yield operation with updated
355+ // results. scf.execute_region must have at least one yield operation.
348356 for (auto *yieldOp : yieldOps) {
349357 rewriter.setInsertionPoint (yieldOp);
350358 rewriter.eraseOp (yieldOp);
351- rewriter.create <scf::YieldOp>(yieldOp->getLoc ());
359+ rewriter.create <scf::YieldOp>(yieldOp->getLoc (),
360+ ValueRange (internalValues));
352361 }
353362
354363 // Replace uses of the old operation's forwarded results with the external values directly.
355- rewriter.replaceOp (op, yieldOpsOperands);
364+ rewriter.replaceAllUsesWith (opResultsToReplaceWithExternalValues,
365+ externalValues);
366+ // Replace the old operation's remaining results with the new operation's
367+ // results.
368+ rewriter.replaceAllUsesWith (opResultsToKeep, newOp.getResults ());
369+ rewriter.eraseOp (op);
356370 return success ();
357371 }
358372
0 commit comments