Skip to content

Commit 25c95eb

Browse files
authored
[flang][fir] Convert fir.do_loop with the unordered attribute to scf.parallel. (#168510)
Refines the existing conversion to allow `fir.do_loop` annotated with `unordered` to be lowered to `scf.parallel`, while other loops retain their original lowering.
1 parent d8ae4d5 commit 25c95eb

File tree

4 files changed

+107
-22
lines changed

4 files changed

+107
-22
lines changed

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ std::unique_ptr<mlir::Pass> createVScaleAttrPass();
5353
std::unique_ptr<mlir::Pass>
5454
createVScaleAttrPass(std::pair<unsigned, unsigned> vscaleAttr);
5555

56+
void populateFIRToSCFRewrites(mlir::RewritePatternSet &patterns,
57+
bool parallelUnordered = false);
58+
5659
void populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
5760
bool forceLoopToExecuteOnce = false,
5861
bool setNSW = true);

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def FIRToSCFPass : Pass<"fir-to-scf"> {
8585
let dependentDialects = [
8686
"fir::FIROpsDialect", "mlir::scf::SCFDialect"
8787
];
88+
let options = [Option<"parallelUnordered", "parallel-unordered", "bool",
89+
/*default=*/"false",
90+
"Allow converting a fir.do_loop with the `unordered` "
91+
"attribute to scf.parallel (experimental).">];
8892
}
8993

9094
def AnnotateConstantOperands : Pass<"annotate-constant"> {

flang/lib/Optimizer/Transforms/FIRToSCF.cpp

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,18 @@ class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
2525
struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
2626
using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
2727

28+
DoLoopConversion(mlir::MLIRContext *context,
29+
bool parallelUnorderedLoop = false,
30+
mlir::PatternBenefit benefit = 1)
31+
: OpRewritePattern<fir::DoLoopOp>(context, benefit),
32+
parallelUnorderedLoop(parallelUnorderedLoop) {}
33+
2834
mlir::LogicalResult
2935
matchAndRewrite(fir::DoLoopOp doLoopOp,
3036
mlir::PatternRewriter &rewriter) const override {
3137
mlir::Location loc = doLoopOp.getLoc();
3238
bool hasFinalValue = doLoopOp.getFinalValue().has_value();
39+
bool isUnordered = doLoopOp.getUnordered().has_value();
3340

3441
// Get loop values from the DoLoopOp
3542
mlir::Value low = doLoopOp.getLowerBound();
@@ -53,39 +60,54 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
5360
mlir::arith::DivSIOp::create(rewriter, loc, distance, step);
5461
auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0);
5562
auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1);
56-
auto scfForOp =
57-
mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs);
5863

64+
// Create the scf.for or scf.parallel operation
65+
mlir::Operation *scfLoopOp = nullptr;
66+
if (isUnordered && parallelUnorderedLoop) {
67+
scfLoopOp = mlir::scf::ParallelOp::create(rewriter, loc, {zero},
68+
{tripCount}, {one}, iterArgs);
69+
} else {
70+
scfLoopOp = mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one,
71+
iterArgs);
72+
}
73+
74+
// Move the body of the fir.do_loop to the scf.for or scf.parallel
5975
auto &loopOps = doLoopOp.getBody()->getOperations();
6076
auto resultOp =
6177
mlir::cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
6278
auto results = resultOp.getOperands();
63-
mlir::Block *loweredBody = scfForOp.getBody();
79+
auto scfLoopLikeOp = mlir::cast<mlir::LoopLikeOpInterface>(scfLoopOp);
80+
mlir::Block &scfLoopBody = scfLoopLikeOp.getLoopRegions().front()->front();
6481

65-
loweredBody->getOperations().splice(loweredBody->begin(), loopOps,
66-
loopOps.begin(),
67-
std::prev(loopOps.end()));
82+
scfLoopBody.getOperations().splice(scfLoopBody.begin(), loopOps,
83+
loopOps.begin(),
84+
std::prev(loopOps.end()));
6885

69-
rewriter.setInsertionPointToStart(loweredBody);
86+
rewriter.setInsertionPointToStart(&scfLoopBody);
7087
mlir::Value iv = mlir::arith::MulIOp::create(
71-
rewriter, loc, scfForOp.getInductionVar(), step);
88+
rewriter, loc, scfLoopLikeOp.getSingleInductionVar().value(), step);
7289
iv = mlir::arith::AddIOp::create(rewriter, loc, low, iv);
7390

7491
if (!results.empty()) {
75-
rewriter.setInsertionPointToEnd(loweredBody);
92+
rewriter.setInsertionPointToEnd(&scfLoopBody);
7693
mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), results);
7794
}
7895
doLoopOp.getInductionVar().replaceAllUsesWith(iv);
79-
rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(),
80-
hasFinalValue
81-
? scfForOp.getRegionIterArgs().drop_front()
82-
: scfForOp.getRegionIterArgs());
83-
84-
// Copy all the attributes from the old to new op.
85-
scfForOp->setAttrs(doLoopOp->getAttrs());
86-
rewriter.replaceOp(doLoopOp, scfForOp);
96+
rewriter.replaceAllUsesWith(
97+
doLoopOp.getRegionIterArgs(),
98+
hasFinalValue ? scfLoopLikeOp.getRegionIterArgs().drop_front()
99+
: scfLoopLikeOp.getRegionIterArgs());
100+
101+
// Copy loop annotations from the fir.do_loop to scf loop op.
102+
if (auto ann = doLoopOp.getLoopAnnotation())
103+
scfLoopOp->setAttr("loop_annotation", *ann);
104+
105+
rewriter.replaceOp(doLoopOp, scfLoopOp);
87106
return mlir::success();
88107
}
108+
109+
private:
110+
bool parallelUnorderedLoop;
89111
};
90112

91113
struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> {
@@ -197,10 +219,15 @@ struct IfConversion : public mlir::OpRewritePattern<fir::IfOp> {
197219
};
198220
} // namespace
199221

222+
void fir::populateFIRToSCFRewrites(mlir::RewritePatternSet &patterns,
223+
bool parallelUnordered) {
224+
patterns.add<IterWhileConversion, IfConversion>(patterns.getContext());
225+
patterns.add<DoLoopConversion>(patterns.getContext(), parallelUnordered);
226+
}
227+
200228
void FIRToSCFPass::runOnOperation() {
201229
mlir::RewritePatternSet patterns(&getContext());
202-
patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>(
203-
patterns.getContext());
230+
fir::populateFIRToSCFRewrites(patterns, parallelUnordered);
204231
walkAndApplyPatterns(getOperation(), std::move(patterns));
205232
}
206233

flang/test/Fir/FirToSCF/do-loop.fir

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: fir-opt %s --fir-to-scf | FileCheck %s
1+
// RUN: fir-opt %s --fir-to-scf --split-input-file | FileCheck %s --check-prefixes=CHECK,NO-PARALLEL
2+
// RUN: fir-opt %s --fir-to-scf='parallel-unordered' --split-input-file | FileCheck %s --check-prefixes=CHECK,PARALLEL
23

34
// CHECK-LABEL: func.func @simple_loop(
45
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
@@ -31,6 +32,8 @@ func.func @simple_loop(%arg0: !fir.ref<!fir.array<100xi32>>) {
3132
return
3233
}
3334

35+
// -----
36+
3437
// CHECK-LABEL: func.func @loop_with_negtive_step(
3538
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
3639
// CHECK: %[[VAL_0:.*]] = arith.constant 100 : index
@@ -64,6 +67,8 @@ func.func @loop_with_negtive_step(%arg0: !fir.ref<!fir.array<100xi32>>) {
6467
return
6568
}
6669

70+
// -----
71+
6772
// CHECK-LABEL: func.func @loop_with_results(
6873
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
6974
// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<i32>) {
@@ -102,6 +107,8 @@ func.func @loop_with_results(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.r
102107
return
103108
}
104109

110+
// -----
111+
105112
// CHECK-LABEL: func.func @loop_with_final_value(
106113
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
107114
// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<i32>) {
@@ -146,6 +153,45 @@ func.func @loop_with_final_value(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !f
146153
return
147154
}
148155

156+
// -----
157+
158+
// CHECK-LABEL: func.func @loop_with_unordered_attr(
159+
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
160+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
161+
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 100 : index
162+
// CHECK: %[[SHAPE_0:.*]] = fir.shape %[[CONSTANT_1]] : (index) -> !fir.shape<1>
163+
// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : i32
164+
// CHECK: %[[SUBI_0:.*]] = arith.subi %[[CONSTANT_1]], %[[CONSTANT_0]] : index
165+
// CHECK: %[[ADDI_0:.*]] = arith.addi %[[SUBI_0]], %[[CONSTANT_0]] : index
166+
// CHECK: %[[DIVSI_0:.*]] = arith.divsi %[[ADDI_0]], %[[CONSTANT_0]] : index
167+
// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : index
168+
// CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : index
169+
// PARALLEL: scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_3]]) to (%[[DIVSI_0]]) step (%[[CONSTANT_4]]) {
170+
// NO-PARALLEL: scf.for %[[VAL_0:.*]] = %[[CONSTANT_3]] to %[[DIVSI_0]] step %[[CONSTANT_4]] {
171+
// CHECK: %[[MULI_0:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_0]] : index
172+
// CHECK: %[[ADDI_1:.*]] = arith.addi %[[CONSTANT_0]], %[[MULI_0]] : index
173+
// CHECK: %[[ARRAY_COOR_0:.*]] = fir.array_coor %[[ARG0]](%[[SHAPE_0]]) %[[ADDI_1]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
174+
// CHECK: fir.store %[[CONSTANT_2]] to %[[ARRAY_COOR_0]] : !fir.ref<i32>
175+
// PARALLEL: scf.reduce
176+
// CHECK: }
177+
// CHECK: return
178+
// CHECK: }
179+
func.func @loop_with_unordered_attr(%arg0: !fir.ref<!fir.array<100xi32>>) {
180+
%c1 = arith.constant 1 : index
181+
%c100 = arith.constant 100 : index
182+
%0 = fir.shape %c100 : (index) -> !fir.shape<1>
183+
%c1_i32 = arith.constant 1 : i32
184+
fir.do_loop %arg1 = %c1 to %c100 step %c1 unordered {
185+
%1 = fir.array_coor %arg0(%0) %arg1 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
186+
fir.store %c1_i32 to %1 : !fir.ref<i32>
187+
}
188+
return
189+
}
190+
191+
// -----
192+
193+
// CHECK: #[[$ATTR_0:.+]] = #llvm.loop_vectorize<disable = false>
194+
// CHECK: #[[$ATTR_1:.+]] = #llvm.loop_annotation<vectorize = #[[$ATTR_0]]>
149195
// CHECK-LABEL: func.func @loop_with_attribute(
150196
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
151197
// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<i32>) {
@@ -167,16 +213,19 @@ func.func @loop_with_final_value(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !f
167213
// CHECK: %[[VAL_15:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
168214
// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : i32
169215
// CHECK: fir.store %[[VAL_16]] to %[[VAL_3]] : !fir.ref<i32>
170-
// CHECK: } {operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>, reduceAttrs = [#fir.reduce_attr<add>]}
216+
// CHECK: } {loop_annotation = #[[$ATTR_1]]}
171217
// CHECK: return
172218
// CHECK: }
219+
220+
#loop_vectorize = #llvm.loop_vectorize<disable = false>
221+
#loop_annotation = #llvm.loop_annotation<vectorize = #loop_vectorize>
173222
func.func @loop_with_attribute(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.ref<i32>) {
174223
%c1 = arith.constant 1 : index
175224
%c0_i32 = arith.constant 0 : i32
176225
%c100 = arith.constant 100 : index
177226
%0 = fir.alloca i32
178227
%1 = fir.shape %c100 : (index) -> !fir.shape<1>
179-
fir.do_loop %arg2 = %c1 to %c100 step %c1 reduce(#fir.reduce_attr<add> -> %0 : !fir.ref<i32>) {
228+
fir.do_loop %arg2 = %c1 to %c100 step %c1 attributes {loopAnnotation = #loop_annotation} {
180229
%2 = fir.array_coor %arg0(%1) %arg2 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
181230
%3 = fir.load %2 : !fir.ref<i32>
182231
%4 = fir.load %0 : !fir.ref<i32>
@@ -187,6 +236,8 @@ func.func @loop_with_attribute(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir
187236
return
188237
}
189238

239+
// -----
240+
190241
// CHECK-LABEL: func.func @nested_loop(
191242
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100x100xi32>>) {
192243
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index

0 commit comments

Comments
 (0)