Skip to content

Commit 6328506

Browse files
authored
[flang][fir] Add rewrite pattern to convert fir.do_concurrent to fir.do_loop (#132207)
Rewrites `fir.do_concurrent` ops to a corresponding nest of `fir.do_loop ... unordered` ops.
1 parent 038cdd2 commit 6328506

File tree

4 files changed

+184
-1
lines changed

4 files changed

+184
-1
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3478,7 +3478,8 @@ def fir_DoConcurrentOp : fir_Op<"do_concurrent",
34783478
}
34793479

34803480
def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
3481-
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface>,
3481+
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface,
3482+
["getLoopInductionVars"]>,
34823483
Terminator, NoTerminator, SingleBlock, ParentOneOf<["DoConcurrentOp"]>]> {
34833484
let summary = "do concurrent loop";
34843485

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4915,6 +4915,11 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
49154915
return mlir::success();
49164916
}
49174917

4918+
/// LoopLikeOpInterface hook: reports every argument of the op's body block as
/// a loop induction variable.
std::optional<llvm::SmallVector<mlir::Value>>
4919+
fir::DoConcurrentLoopOp::getLoopInductionVars() {
4920+
return llvm::SmallVector<mlir::Value>{getBody()->getArguments()};
4921+
}
4922+
49184923
//===----------------------------------------------------------------------===//
49194924
// FIROpsDialect
49204925
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
#include "flang/Optimizer/Builder/Todo.h"
1919
#include "flang/Optimizer/Dialect/FIROps.h"
2020
#include "flang/Optimizer/Transforms/Passes.h"
21+
#include "mlir/IR/IRMapping.h"
2122
#include "mlir/Pass/Pass.h"
2223
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24+
#include <optional>
2325

2426
namespace fir {
2527
#define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS
@@ -122,6 +124,57 @@ mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite(
122124
return mlir::failure();
123125
}
124126

127+
/// Rewrite pattern that replaces a `fir.do_concurrent` op with a nest of
/// `fir.do_loop ... unordered` ops: one loop is created per
/// (lower bound, upper bound, step) triple of the wrapped
/// `fir.do_concurrent.loop`, each nested inside the previous one.
class DoConcurrentConversion
128+
: public mlir::OpRewritePattern<fir::DoConcurrentOp> {
129+
public:
130+
using mlir::OpRewritePattern<fir::DoConcurrentOp>::OpRewritePattern;
131+
132+
/// Moves every op that precedes the `fir.do_concurrent.loop` terminator to
/// the beginning of the block returned by `getAllocaBlock()`, builds the
/// `fir.do_loop` nest right after the original op, inlines the original loop
/// body into the innermost new loop, and erases the `fir.do_concurrent` op.
/// The only return path yields `mlir::success()`.
mlir::LogicalResult
133+
matchAndRewrite(fir::DoConcurrentOp doConcurentOp,
134+
mlir::PatternRewriter &rewriter) const override {
135+
// The wrapper region is expected to hold a single block whose terminator is
// the `fir.do_concurrent.loop` op.
assert(doConcurentOp.getRegion().hasOneBlock());
136+
mlir::Block &wrapperBlock = doConcurentOp.getRegion().getBlocks().front();
137+
auto loop =
138+
mlir::cast<fir::DoConcurrentLoopOp>(wrapperBlock.getTerminator());
139+
assert(loop.getRegion().hasOneBlock());
140+
mlir::Block &loopBlock = loop.getRegion().getBlocks().front();
141+
142+
// Collect iteration variable(s) allocations so that we can move them
143+
// outside the `fir.do_concurrent` wrapper.
144+
// Note: every op in the wrapper block except its terminator is collected.
llvm::SmallVector<mlir::Operation *> opsToMove;
145+
for (mlir::Operation &op : llvm::drop_end(wrapperBlock))
146+
opsToMove.push_back(&op);
147+
148+
fir::FirOpBuilder firBuilder(
149+
rewriter, doConcurentOp->getParentOfType<mlir::ModuleOp>());
150+
auto *allocIt = firBuilder.getAllocaBlock();
151+
152+
// Walk the collected ops in reverse while always inserting at the block
// start, so the moved ops keep their original relative order.
for (mlir::Operation *op : llvm::reverse(opsToMove))
153+
rewriter.moveOpBefore(op, allocIt, allocIt->begin());
154+
155+
rewriter.setInsertionPointAfter(doConcurentOp);
156+
fir::DoLoopOp innermostUnorderdLoop;
157+
mlir::SmallVector<mlir::Value> ivArgs;
158+
159+
// Build the loop nest outermost-first: after each loop is created, the
// insertion point moves to the start of its body so the next loop nests
// inside it. The same reduce operands/attributes are passed to every loop.
for (auto [lb, ub, st, iv] :
160+
llvm::zip_equal(loop.getLowerBound(), loop.getUpperBound(),
161+
loop.getStep(), *loop.getLoopInductionVars())) {
162+
innermostUnorderdLoop = rewriter.create<fir::DoLoopOp>(
163+
doConcurentOp.getLoc(), lb, ub, st,
164+
/*unordered=*/true, /*finalCountValue=*/false,
165+
/*iterArgs=*/std::nullopt, loop.getReduceOperands(),
166+
loop.getReduceAttrsAttr());
167+
// Remember each new induction variable, in nest order, for the body
// inlining below.
ivArgs.push_back(innermostUnorderdLoop.getInductionVar());
168+
rewriter.setInsertionPointToStart(innermostUnorderdLoop.getBody());
169+
}
170+
171+
// Inline the original loop body before the innermost loop's terminator,
// passing the new induction variables as the values for the inlined
// block's arguments.
rewriter.inlineBlockBefore(
172+
&loopBlock, innermostUnorderdLoop.getBody()->getTerminator(), ivArgs);
173+
rewriter.eraseOp(doConcurentOp);
174+
return mlir::success();
175+
}
176+
};
177+
125178
void SimplifyFIROperationsPass::runOnOperation() {
126179
mlir::ModuleOp module = getOperation();
127180
mlir::MLIRContext &context = getContext();
@@ -142,4 +195,5 @@ void fir::populateSimplifyFIROperationsPatterns(
142195
mlir::RewritePatternSet &patterns, bool preferInlineImplementation) {
143196
patterns.insert<IsContiguousBoxCoversion, BoxTotalElementsConversion>(
144197
patterns.getContext(), preferInlineImplementation);
198+
patterns.insert<DoConcurrentConversion>(patterns.getContext());
145199
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Tests converting `fir.do_concurrent` ops to `fir.do_loop` ops.
2+
3+
// RUN: fir-opt --split-input-file --simplify-fir-operations %s | FileCheck %s
4+
5+
func.func @dc_1d(%i_lb: index, %i_ub: index, %i_st: index) {
6+
fir.do_concurrent {
7+
%i = fir.alloca i32
8+
fir.do_concurrent.loop (%i_iv) = (%i_lb) to (%i_ub) step (%i_st) {
9+
%0 = fir.convert %i_iv : (index) -> i32
10+
fir.store %0 to %i : !fir.ref<i32>
11+
}
12+
}
13+
return
14+
}
15+
16+
// CHECK-LABEL: func.func @dc_1d(
17+
// CHECK-SAME: %[[I_LB:[^[:space:]]+]]: index,
18+
// CHECK-SAME: %[[I_UB:[^[:space:]]+]]: index,
19+
// CHECK-SAME: %[[I_ST:[^[:space:]]+]]: index) {
20+
21+
// CHECK: %[[I:.*]] = fir.alloca i32
22+
23+
// CHECK: fir.do_loop %[[I_IV:.*]] = %[[I_LB]] to %[[I_UB]] step %[[I_ST]] unordered {
24+
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
25+
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
26+
// CHECK: }
27+
28+
// CHECK: return
29+
// CHECK: }
30+
31+
// -----
32+
33+
func.func @dc_2d(%i_lb: index, %i_ub: index, %i_st: index,
34+
%j_lb: index, %j_ub: index, %j_st: index) {
35+
llvm.br ^bb1
36+
37+
^bb1:
38+
fir.do_concurrent {
39+
%i = fir.alloca i32
40+
%j = fir.alloca i32
41+
fir.do_concurrent.loop
42+
(%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st) {
43+
%0 = fir.convert %i_iv : (index) -> i32
44+
fir.store %0 to %i : !fir.ref<i32>
45+
46+
%1 = fir.convert %j_iv : (index) -> i32
47+
fir.store %1 to %j : !fir.ref<i32>
48+
}
49+
}
50+
return
51+
}
52+
53+
// CHECK-LABEL: func.func @dc_2d(
54+
// CHECK-SAME: %[[I_LB:[^[:space:]]+]]: index,
55+
// CHECK-SAME: %[[I_UB:[^[:space:]]+]]: index,
56+
// CHECK-SAME: %[[I_ST:[^[:space:]]+]]: index,
57+
// CHECK-SAME: %[[J_LB:[^[:space:]]+]]: index,
58+
// CHECK-SAME: %[[J_UB:[^[:space:]]+]]: index,
59+
// CHECK-SAME: %[[J_ST:[^[:space:]]+]]: index) {
60+
61+
// CHECK: %[[I:.*]] = fir.alloca i32
62+
// CHECK: %[[J:.*]] = fir.alloca i32
63+
// CHECK: llvm.br ^bb1
64+
65+
// CHECK: ^bb1:
66+
// CHECK: fir.do_loop %[[I_IV:.*]] = %[[I_LB]] to %[[I_UB]] step %[[I_ST]] unordered {
67+
// CHECK: fir.do_loop %[[J_IV:.*]] = %[[J_LB]] to %[[J_UB]] step %[[J_ST]] unordered {
68+
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
69+
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
70+
// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
71+
// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
72+
// CHECK: }
73+
// CHECK: }
74+
75+
// CHECK: return
76+
// CHECK: }
77+
78+
// -----
79+
80+
func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
81+
%j_lb: index, %j_ub: index, %j_st: index) {
82+
%sum = fir.alloca i32
83+
84+
fir.do_concurrent {
85+
%i = fir.alloca i32
86+
%j = fir.alloca i32
87+
fir.do_concurrent.loop
88+
(%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
89+
reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
90+
%0 = fir.convert %i_iv : (index) -> i32
91+
fir.store %0 to %i : !fir.ref<i32>
92+
93+
%1 = fir.convert %j_iv : (index) -> i32
94+
fir.store %1 to %j : !fir.ref<i32>
95+
}
96+
}
97+
return
98+
}
99+
100+
// CHECK-LABEL: func.func @dc_2d_reduction(
101+
// CHECK-SAME: %[[I_LB:[^[:space:]]+]]: index,
102+
// CHECK-SAME: %[[I_UB:[^[:space:]]+]]: index,
103+
// CHECK-SAME: %[[I_ST:[^[:space:]]+]]: index,
104+
// CHECK-SAME: %[[J_LB:[^[:space:]]+]]: index,
105+
// CHECK-SAME: %[[J_UB:[^[:space:]]+]]: index,
106+
// CHECK-SAME: %[[J_ST:[^[:space:]]+]]: index) {
107+
108+
// CHECK: %[[I:.*]] = fir.alloca i32
109+
// CHECK: %[[J:.*]] = fir.alloca i32
110+
// CHECK: %[[SUM:.*]] = fir.alloca i32
111+
112+
// CHECK: fir.do_loop %[[I_IV:.*]] = %[[I_LB]] to %[[I_UB]] step %[[I_ST]] unordered reduce({{.*}}<add>] -> %[[SUM]] : !fir.ref<i32>) {
113+
// CHECK: fir.do_loop %[[J_IV:.*]] = %[[J_LB]] to %[[J_UB]] step %[[J_ST]] unordered reduce({{.*}}<add>] -> %[[SUM]] : !fir.ref<i32>) {
114+
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
115+
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
116+
// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
117+
// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
118+
// CHECK: fir.result
119+
// CHECK: }
120+
// CHECK: fir.result
121+
// CHECK: }
122+
// CHECK: return
123+
// CHECK: }

0 commit comments

Comments
 (0)