@@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
312312
313313} // namespace
314314
315+ static void propagateLoopAttrs (Operation *scfOp, Operation *brOp) {
316+ // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
317+ // llvm.loop_annotation attribute.
318+ // LLVM requires the loop metadata to be attached on the "latch" block. Which
319+ // is the back-edge to the header block (conditionBlock)
320+ SmallVector<NamedAttribute> llvmAttrs;
321+ llvm::copy_if (scfOp->getAttrs (), std::back_inserter (llvmAttrs),
322+ [](auto attr) {
323+ return isa<LLVM::LLVMDialect>(attr.getValue ().getDialect ());
324+ });
325+ brOp->setDiscardableAttrs (llvmAttrs);
326+ }
327+
315328LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
316329 PatternRewriter &rewriter) const {
317330 Location loc = forOp.getLoc ();
@@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
350363 auto branchOp =
351364 cf::BranchOp::create (rewriter, loc, conditionBlock, loopCarried);
352365
353- // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
354- // llvm.loop_annotation attribute.
355- // LLVM requires the loop metadata to be attached on the "latch" block. Which
356- // is the back-edge to the header block (conditionBlock)
357- SmallVector<NamedAttribute> llvmAttrs;
358- llvm::copy_if (forOp->getAttrs (), std::back_inserter (llvmAttrs),
359- [](auto attr) {
360- return isa<LLVM::LLVMDialect>(attr.getValue ().getDialect ());
361- });
362- branchOp->setDiscardableAttrs (llvmAttrs);
363-
366+ propagateLoopAttrs (forOp, branchOp);
364367 rewriter.eraseOp (terminator);
365368
366369 // Compute loop bounds before branching to the condition.
@@ -589,9 +592,10 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
589592
590593 rewriter.setInsertionPointToEnd (after);
591594 auto yieldOp = cast<scf::YieldOp>(after->getTerminator ());
592- rewriter.replaceOpWithNewOp <cf::BranchOp>(yieldOp, before,
593- yieldOp.getResults ());
595+ auto latch = rewriter.replaceOpWithNewOp <cf::BranchOp>(yieldOp, before,
596+ yieldOp.getResults ());
594597
598+ propagateLoopAttrs (whileOp, latch);
595599 // Replace the op with values "yielded" from the "before" region, which are
596600 // visible by dominance.
597601 rewriter.replaceOp (whileOp, args);
@@ -631,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
631635 // Loop around the "before" region based on condition.
632636 rewriter.setInsertionPointToEnd (before);
633637 auto condOp = cast<ConditionOp>(before->getTerminator ());
634- cf::CondBranchOp::create (rewriter, condOp. getLoc (), condOp. getCondition (),
635- before , condOp.getArgs (), continuation ,
636- ValueRange ());
638+ auto latch = cf::CondBranchOp::create (
639+ rewriter, condOp. getLoc () , condOp.getCondition (), before ,
640+ condOp. getArgs (), continuation, ValueRange ());
637641
642+ propagateLoopAttrs (whileOp, latch);
638643 // Replace the op with values "yielded" from the "before" region, which are
639644 // visible by dominance.
640645 rewriter.replaceOp (whileOp, condOp.getArgs ());
0 commit comments