@@ -347,7 +347,20 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
347347 SmallVector<Value, 8 > loopCarried;
348348 loopCarried.push_back (stepped);
349349 loopCarried.append (terminator->operand_begin (), terminator->operand_end ());
350- rewriter.create <cf::BranchOp>(loc, conditionBlock, loopCarried);
350+ auto branchOp =
351+ rewriter.create <cf::BranchOp>(loc, conditionBlock, loopCarried);
352+
353+ // Let the BranchOp carry the LLVM attributes from the ForOp, such as the
354+ // llvm.loop_annotation attribute.
355+ // LLVM requires the loop metadata to be attached to the "latch" block, i.e.
356+ // the block containing the back-edge to the header block (conditionBlock).
357+ SmallVector<NamedAttribute> llvmAttrs;
358+ llvm::copy_if (forOp->getAttrs (), std::back_inserter (llvmAttrs),
359+ [](auto attr) {
360+ return isa<LLVM::LLVMDialect>(attr.getValue ().getDialect ());
361+ });
362+ branchOp->setDiscardableAttrs (llvmAttrs);
363+
351364 rewriter.eraseOp (terminator);
352365
353366 // Compute loop bounds before branching to the condition.
@@ -369,18 +382,10 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
369382 auto comparison = rewriter.create <arith::CmpIOp>(
370383 loc, arith::CmpIPredicate::slt, iv, upperBound);
371384
372- auto condBranchOp = rewriter.create <cf::CondBranchOp>(
373- loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock,
374- ArrayRef<Value>());
385+ rewriter.create <cf::CondBranchOp>(loc, comparison, firstBodyBlock,
386+ ArrayRef<Value>(), endBlock,
387+ ArrayRef<Value>());
375388
376- // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
377- // llvm.loop_annotation attribute.
378- SmallVector<NamedAttribute> llvmAttrs;
379- llvm::copy_if (forOp->getAttrs (), std::back_inserter (llvmAttrs),
380- [](auto attr) {
381- return isa<LLVM::LLVMDialect>(attr.getValue ().getDialect ());
382- });
383- condBranchOp->setDiscardableAttrs (llvmAttrs);
384389 // The result of the loop operation is the values of the condition block
385390 // arguments except the induction variable on the last iteration.
386391 rewriter.replaceOp (forOp, conditionBlock->getArguments ().drop_front ());
0 commit comments