@@ -291,9 +291,94 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
291291 }
292292};
293293
294+ // Pattern to eliminate ExecuteRegionOp results when it only forwards external
295+ // values. It operates only on execute regions with single terminator yield
296+ // operation.
297+ struct ExecuteRegionForwardingEliminator
298+ : public OpRewritePattern<ExecuteRegionOp> {
299+ using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
300+
301+ LogicalResult matchAndRewrite (ExecuteRegionOp op,
302+ PatternRewriter &rewriter) const override {
303+ if (op.getNumResults () == 0 )
304+ return failure ();
305+
306+ SmallVector<Operation *> yieldOps;
307+ for (Block &block : op.getRegion ()) {
308+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator ())) {
309+ if (yield.getResults ().empty ())
310+ continue ;
311+ yieldOps.push_back (yield.getOperation ());
312+ }
313+ }
314+
315+ if (yieldOps.size () != 1 )
316+ return failure ();
317+
318+ auto yieldOp = dyn_cast<scf::YieldOp>(yieldOps.front ());
319+ auto yieldedValues = yieldOp.getOperands ();
320+ // Check if all yielded values are from outside the region
321+ bool allExternal = true ;
322+ for (Value yieldedValue : yieldedValues) {
323+ if (isValueFromInsideRegion (yieldedValue, op)) {
324+ allExternal = false ;
325+ break ;
326+ }
327+ }
328+
329+ if (!allExternal)
330+ return failure ();
331+
332+ // All yielded values are external - create a new execute_region with no
333+ // results.
334+ auto newOp = rewriter.create <ExecuteRegionOp>(op.getLoc (), TypeRange{});
335+ newOp->setAttrs (op->getAttrs ());
336+
337+ // Move the region content to the new operation
338+ newOp.getRegion ().takeBody (op.getRegion ());
339+
340+ // Replace the yield operation with a new yield operation with no results.
341+ rewriter.setInsertionPoint (yieldOp);
342+ rewriter.eraseOp (yieldOp);
343+ rewriter.create <scf::YieldOp>(yieldOp.getLoc ());
344+
345+ // Replace the old operation with the external values directly.
346+ rewriter.replaceOp (op, yieldedValues);
347+ return success ();
348+ }
349+
350+ private:
351+ bool isValueFromInsideRegion (Value value,
352+ ExecuteRegionOp executeRegionOp) const {
353+ // Check if the value is defined within the execute_region
354+ if (Operation *defOp = value.getDefiningOp ()) {
355+ return executeRegionOp.getRegion ().isAncestor (defOp->getParentRegion ());
356+ }
357+
358+ // If it's a block argument, check if it's from within the region
359+ if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) {
360+ return executeRegionOp.getRegion ().isAncestor (blockArg.getParentRegion ());
361+ }
362+
363+ return false ; // Value is from outside the region
364+ }
365+ };
366+
294367void ExecuteRegionOp::getCanonicalizationPatterns (RewritePatternSet &results,
295368 MLIRContext *context) {
296- results.add <SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
369+ results.add <SingleBlockExecuteInliner, MultiBlockExecuteInliner,
370+ ExecuteRegionForwardingEliminator>(context);
371+ }
372+
373+ void ExecuteRegionOp::getEffects (
374+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
375+ &effects) {
376+ if (!getNoInline ())
377+ return ;
378+ // In case there is attribute no_inline we want the region not to be inlined
379+ // into the parent operation.
380+ effects.emplace_back (MemoryEffects::Write::get (),
381+ SideEffects::DefaultResource::get ());
297382}
298383
299384void ExecuteRegionOp::getSuccessorRegions (
0 commit comments