Skip to content

Commit fbf81e3

Browse files
Enable attaching LLVM loop annotations to scf.for (#102562)
We recently discovered that a loop with a dynamic upper bound is unexpectedly unrolled during NVVM-to-PTX lowering. Attaching an `llvm.loop_annotation` attribute lets us control the unrolling behavior precisely. This PR makes the `cf.cond_br` produced by the `convert-scf-to-cf` pass retain the loop annotation of the original `scf.for`. This change allows users to have precise control over loop behavior during backend transformations. --------- Co-authored-by: Xiaolei Shi <[email protected]>
1 parent 3176f25 commit fbf81e3

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1819
#include "mlir/Dialect/SCF/IR/SCF.h"
1920
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
2021
#include "mlir/IR/Builders.h"
@@ -370,9 +371,18 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
370371
auto comparison = rewriter.create<arith::CmpIOp>(
371372
loc, arith::CmpIPredicate::slt, iv, upperBound);
372373

373-
rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
374-
ArrayRef<Value>(), endBlock,
375-
ArrayRef<Value>());
374+
auto condBranchOp = rewriter.create<cf::CondBranchOp>(
375+
loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock,
376+
ArrayRef<Value>());
377+
378+
// Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
379+
// llvm.loop_annotation attribute.
380+
SmallVector<NamedAttribute> llvmAttrs;
381+
llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
382+
[](auto attr) {
383+
return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
384+
});
385+
condBranchOp->setDiscardableAttrs(llvmAttrs);
376386
// The result of the loop operation is the values of the condition block
377387
// arguments except the induction variable on the last iteration.
378388
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());

mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf %s | FileCheck %s
1+
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s
22

33
// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
44
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
@@ -675,3 +675,25 @@ func.func @forall(%num_threads: index) {
675675
}
676676
return
677677
}
678+
679+
// -----
680+
681+
// CHECK: #loop_unroll = #llvm.loop_unroll<disable = true>
682+
// CHECK-NEXT: #loop_unroll1 = #llvm.loop_unroll<full = true>
683+
// CHECK-NEXT: #[[NO_UNROLL:.*]] = #llvm.loop_annotation<unroll = #loop_unroll>
684+
// CHECK-NEXT: #[[FULL_UNROLL:.*]] = #llvm.loop_annotation<unroll = #loop_unroll1>
685+
// CHECK: cf.cond_br %{{.*}}, ^bb2, ^bb6 {llvm.loop_annotation = #[[NO_UNROLL]]}
686+
// CHECK: cf.cond_br %{{.*}}, ^bb4, ^bb5 {llvm.loop_annotation = #[[FULL_UNROLL]]}
687+
#no_unroll = #llvm.loop_annotation<unroll = <disable = true>>
688+
#full_unroll = #llvm.loop_annotation<unroll = <full = true>>
689+
func.func @simple_std_for_loops_annotation(%arg0 : index, %arg1 : index, %arg2 : index) {
690+
scf.for %i0 = %arg0 to %arg1 step %arg2 {
691+
%c0 = arith.constant 0 : index
692+
%c1 = arith.constant 1 : index
693+
%c4 = arith.constant 4 : index
694+
scf.for %i1 = %c0 to %c4 step %c1 {
695+
%c1_0 = arith.constant 1 : index
696+
} {llvm.loop_annotation = #full_unroll}
697+
} {llvm.loop_annotation = #no_unroll}
698+
return
699+
}

0 commit comments

Comments
 (0)