@@ -291,9 +291,89 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
291291 }
292292};
293293
294+ // Pattern to eliminate ExecuteRegionOp results when it only forwards external values.
295+ // It operates only on execute regions with single terminator yield operation.
296+ struct ExecuteRegionForwardingEliminator : public OpRewritePattern <ExecuteRegionOp> {
297+ using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
298+
299+ LogicalResult matchAndRewrite (ExecuteRegionOp op,
300+ PatternRewriter &rewriter) const override {
301+ if (op.getNumResults () == 0 )
302+ return failure ();
303+
304+ SmallVector<Operation*> yieldOps;
305+ for (Block &block : op.getRegion ()) {
306+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator ())) {
307+ if (yield.getResults ().empty ())
308+ continue ;
309+ yieldOps.push_back (yield.getOperation ());
310+ }
311+ }
312+
313+ if (yieldOps.size () != 1 )
314+ return failure ();
315+
316+ auto yieldOp = dyn_cast<scf::YieldOp>(yieldOps.front ());
317+ auto yieldedValues = yieldOp.getOperands ();
318+ // Check if all yielded values are from outside the region
319+ bool allExternal = true ;
320+ for (Value yieldedValue : yieldedValues) {
321+ if (isValueFromInsideRegion (yieldedValue, op)) {
322+ allExternal = false ;
323+ break ;
324+ }
325+ }
326+
327+ if (!allExternal)
328+ return failure ();
329+
330+ // All yielded values are external - create a new execute_region with no results.
331+ auto newOp = rewriter.create <ExecuteRegionOp>(op.getLoc (), TypeRange{});
332+ newOp->setAttrs (op->getAttrs ());
333+
334+ // Move the region content to the new operation
335+ newOp.getRegion ().takeBody (op.getRegion ());
336+
337+ // Replace the yield operation with a new yield operation with no results.
338+ rewriter.setInsertionPoint (yieldOp);
339+ rewriter.eraseOp (yieldOp);
340+ rewriter.create <scf::YieldOp>(yieldOp.getLoc ());
341+
342+ // Replace the old operation with the external values directly.
343+ rewriter.replaceOp (op, yieldedValues);
344+ return success ();
345+ }
346+
347+ private:
348+ bool isValueFromInsideRegion (Value value, ExecuteRegionOp executeRegionOp) const {
349+ // Check if the value is defined within the execute_region
350+ if (Operation *defOp = value.getDefiningOp ()) {
351+ return executeRegionOp.getRegion ().isAncestor (defOp->getParentRegion ());
352+ }
353+
354+ // If it's a block argument, check if it's from within the region
355+ if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) {
356+ return executeRegionOp.getRegion ().isAncestor (blockArg.getParentRegion ());
357+ }
358+
359+ return false ; // Value is from outside the region
360+ }
361+ };
362+
294363void ExecuteRegionOp::getCanonicalizationPatterns (RewritePatternSet &results,
295- MLIRContext *context) {
296- results.add <SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
364+ MLIRContext *context) {
365+ results.add <SingleBlockExecuteInliner, MultiBlockExecuteInliner,
366+ ExecuteRegionForwardingEliminator>(context);
367+ }
368+
369+ void ExecuteRegionOp::getEffects (
370+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
371+ &effects) {
372+ if (!getNoInline ())
373+ return ;
374+ // In case there is attribute no_inline we want the region not to be inlined into the parent operation.
375+ effects.emplace_back (MemoryEffects::Write::get (),
376+ SideEffects::DefaultResource::get ());
297377}
298378
299379void ExecuteRegionOp::getSuccessorRegions (
0 commit comments