@@ -372,15 +372,16 @@ static void generateUnrolledLoop(
372372 loopBodyBlock->getTerminator ()->setOperands (lastYielded);
373373}
374374
375- // / Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
376- LogicalResult mlir::loopUnrollByFactor (
375+ // / Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
376+ // / eplilog loop, if the loop is unrolled.
377+ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor (
377378 scf::ForOp forOp, uint64_t unrollFactor,
378379 function_ref<void (unsigned , Operation *, OpBuilder)> annotateFn) {
379380 assert (unrollFactor > 0 && " expected positive unroll factor" );
380381
381382 // Return if the loop body is empty.
382383 if (llvm::hasSingleElement (forOp.getBody ()->getOperations ()))
383- return success () ;
384+ return UnrolledLoopInfo{forOp, std:: nullopt } ;
384385
385386 // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
386387 // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -402,7 +403,7 @@ LogicalResult mlir::loopUnrollByFactor(
402403 if (*constTripCount == 1 &&
403404 failed (forOp.promoteIfSingleIteration (rewriter)))
404405 return failure ();
405- return success () ;
406+ return UnrolledLoopInfo{forOp, std:: nullopt } ;
406407 }
407408
408409 int64_t tripCountEvenMultiple =
@@ -450,6 +451,8 @@ LogicalResult mlir::loopUnrollByFactor(
450451 boundsBuilder.create <arith::MulIOp>(loc, step, unrollFactorCst);
451452 }
452453
454+ UnrolledLoopInfo resultLoops;
455+
453456 // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
454457 if (generateEpilogueLoop) {
455458 OpBuilder epilogueBuilder (forOp->getContext ());
@@ -467,7 +470,8 @@ LogicalResult mlir::loopUnrollByFactor(
467470 }
468471 epilogueForOp->setOperands (epilogueForOp.getNumControlOperands (),
469472 epilogueForOp.getInitArgs ().size (), results);
470- (void )epilogueForOp.promoteIfSingleIteration (rewriter);
473+ if (epilogueForOp.promoteIfSingleIteration (rewriter).failed ())
474+ resultLoops.epilogueLoopOp = epilogueForOp;
471475 }
472476
473477 // Create unrolled loop.
@@ -489,8 +493,9 @@ LogicalResult mlir::loopUnrollByFactor(
489493 },
490494 annotateFn, iterArgs, yieldedValues);
491495 // Promote the loop body up if this has turned into a single iteration loop.
492- (void )forOp.promoteIfSingleIteration (rewriter);
493- return success ();
496+ if (forOp.promoteIfSingleIteration (rewriter).failed ())
497+ resultLoops.mainLoopOp = forOp;
498+ return resultLoops;
494499}
495500
496501// / Check if bounds of all inner loops are defined outside of `forOp`
0 commit comments