@@ -291,9 +291,93 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
291291 }
292292};
293293
294+ // Pattern to eliminate ExecuteRegionOp results when it only forwards external values.
295+ // It operates only on execute regions with single terminator yield operation.
296+ struct ExecuteRegionForwardingEliminator
297+ : public OpRewritePattern<ExecuteRegionOp> {
298+ using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
299+
300+ LogicalResult matchAndRewrite (ExecuteRegionOp op,
301+ PatternRewriter &rewriter) const override {
302+ if (op.getNumResults () == 0 )
303+ return failure ();
304+
305+ SmallVector<Operation *> yieldOps;
306+ for (Block &block : op.getRegion ()) {
307+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator ())) {
308+ if (yield.getResults ().empty ())
309+ continue ;
310+ yieldOps.push_back (yield.getOperation ());
311+ }
312+ }
313+
314+ if (yieldOps.size () != 1 )
315+ return failure ();
316+
317+ auto yieldOp = dyn_cast<scf::YieldOp>(yieldOps.front ());
318+ auto yieldedValues = yieldOp.getOperands ();
319+ // Check if all yielded values are from outside the region
320+ bool allExternal = true ;
321+ for (Value yieldedValue : yieldedValues) {
322+ if (isValueFromInsideRegion (yieldedValue, op)) {
323+ allExternal = false ;
324+ break ;
325+ }
326+ }
327+
328+ if (!allExternal)
329+ return failure ();
330+
331+ // All yielded values are external - create a new execute_region with no
332+ // results.
333+ auto newOp = rewriter.create <ExecuteRegionOp>(op.getLoc (), TypeRange{});
334+ newOp->setAttrs (op->getAttrs ());
335+
336+ // Move the region content to the new operation
337+ newOp.getRegion ().takeBody (op.getRegion ());
338+
339+ // Replace the yield operation with a new yield operation with no results.
340+ rewriter.setInsertionPoint (yieldOp);
341+ rewriter.eraseOp (yieldOp);
342+ rewriter.create <scf::YieldOp>(yieldOp.getLoc ());
343+
344+ // Replace the old operation with the external values directly.
345+ rewriter.replaceOp (op, yieldedValues);
346+ return success ();
347+ }
348+
349+ private:
350+ bool isValueFromInsideRegion (Value value,
351+ ExecuteRegionOp executeRegionOp) const {
352+ // Check if the value is defined within the execute_region
353+ if (Operation *defOp = value.getDefiningOp ()) {
354+ return executeRegionOp.getRegion ().isAncestor (defOp->getParentRegion ());
355+ }
356+
357+ // If it's a block argument, check if it's from within the region
358+ if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) {
359+ return executeRegionOp.getRegion ().isAncestor (blockArg.getParentRegion ());
360+ }
361+
362+ return false ; // Value is from outside the region
363+ }
364+ };
365+
294366void ExecuteRegionOp::getCanonicalizationPatterns (RewritePatternSet &results,
295367 MLIRContext *context) {
296- results.add <SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
368+ results.add <SingleBlockExecuteInliner, MultiBlockExecuteInliner, ExecuteRegionForwardingEliminator>(
369+ context);
370+ }
371+
372+ void ExecuteRegionOp::getEffects (
373+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
374+ &effects) {
375+ if (!getNoInline ())
376+ return ;
377+ // In case there is attribute no_inline we want the region not to be inlined
378+ // into the parent operation.
379+ effects.emplace_back (MemoryEffects::Write::get (),
380+ SideEffects::DefaultResource::get ());
297381}
298382
299383void ExecuteRegionOp::getSuccessorRegions (
0 commit comments