1+ // ===-- SimdOnly.cpp ------------------------------------------------------===//
2+ //
3+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+ // See https://llvm.org/LICENSE.txt for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ // ===----------------------------------------------------------------------===//
8+
19#include " flang/Optimizer/Builder/FIRBuilder.h"
210#include " flang/Optimizer/Transforms/Utils.h"
11+ #include " mlir/Dialect/Arith/IR/Arith.h"
312#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
413#include " mlir/Dialect/Func/IR/FuncOps.h"
514#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
6- #include " mlir/IR/IRMapping.h"
15+ #include " mlir/IR/MLIRContext.h"
16+ #include " mlir/IR/Operation.h"
17+ #include " mlir/IR/PatternMatch.h"
718#include " mlir/Pass/Pass.h"
8- #include " mlir/Transforms/DialectConversion .h"
19+ #include " mlir/Support/LLVM .h"
920#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
10- #include < llvm/Support/Debug.h>
11- #include < mlir/IR/MLIRContext.h>
12- #include < mlir/IR/Operation.h>
13- #include < mlir/IR/PatternMatch.h>
14- #include < mlir/Support/LLVM.h>
21+ #include " llvm/Support/Debug.h"
1522
1623namespace flangomp {
1724#define GEN_PASS_DEF_SIMDONLYPASS
@@ -44,8 +51,15 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
4451 return rewriter.notifyMatchFailure (op, " Op is a plain SimdOp" );
4552 }
4653
47- if (op->getParentOfType <mlir::omp::SimdOp>())
48- return rewriter.notifyMatchFailure (op, " Op is nested under a SimdOp" );
54+ if (op->getParentOfType <mlir::omp::SimdOp>() &&
55+ (mlir::isa<mlir::omp::YieldOp>(op) ||
56+ mlir::isa<mlir::omp::LoopNestOp>(op) ||
57+ mlir::isa<mlir::omp::WsloopOp>(op) ||
58+ mlir::isa<mlir::omp::WorkshareLoopWrapperOp>(op) ||
59+ mlir::isa<mlir::omp::DistributeOp>(op) ||
60+ mlir::isa<mlir::omp::TaskloopOp>(op) ||
61+ mlir::isa<mlir::omp::TerminatorOp>(op)))
62+ return rewriter.notifyMatchFailure (op, " Op is part of a simd construct" );
4963
5064 if (!mlir::isa<mlir::func::FuncOp>(op->getParentOp ()) &&
5165 (mlir::isa<mlir::omp::TerminatorOp>(op) ||
@@ -67,6 +81,28 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
6781 LLVM_DEBUG (llvm::dbgs () << " SimdOnlyPass matched OpenMP op:\n " );
6882 LLVM_DEBUG (op->dump ());
6983
84+ auto eraseUnlessUsedBySimd = [&](mlir::Operation *ompOp,
85+ mlir::StringAttr name) {
86+ if (auto uses =
87+ mlir::SymbolTable::getSymbolUses (name, op->getParentOp ())) {
88+ for (auto &use : *uses)
89+ if (mlir::isa<mlir::omp::SimdOp>(use.getUser ()))
90+ return rewriter.notifyMatchFailure (op,
91+ " Op used by a simd construct" );
92+ }
93+ rewriter.eraseOp (ompOp);
94+ return mlir::success ();
95+ };
96+
97+ if (auto ompOp = mlir::dyn_cast<mlir::omp::PrivateClauseOp>(op))
98+ return eraseUnlessUsedBySimd (ompOp, ompOp.getSymNameAttr ());
99+ if (auto ompOp = mlir::dyn_cast<mlir::omp::DeclareReductionOp>(op))
100+ return eraseUnlessUsedBySimd (ompOp, ompOp.getSymNameAttr ());
101+ if (auto ompOp = mlir::dyn_cast<mlir::omp::CriticalDeclareOp>(op))
102+ return eraseUnlessUsedBySimd (ompOp, ompOp.getSymNameAttr ());
103+ if (auto ompOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(op))
104+ return eraseUnlessUsedBySimd (ompOp, ompOp.getSymNameAttr ());
105+
70106 // Erase ops that don't need any special handling
71107 if (mlir::isa<mlir::omp::BarrierOp>(op) ||
72108 mlir::isa<mlir::omp::FlushOp>(op) ||
@@ -87,75 +123,19 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
87123 fir::FirOpBuilder builder (rewriter, op);
88124 mlir::Location loc = op->getLoc ();
89125
90- auto inlineSimpleOp = [&](mlir::Operation *ompOp) -> bool {
91- if (!ompOp)
92- return false ;
93-
94- llvm::SmallVector<std::pair<mlir::Value, mlir::BlockArgument>>
95- blockArgsPairs;
96- if (auto iface =
97- mlir::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(op)) {
98- iface.getBlockArgsPairs (blockArgsPairs);
99- for (auto [value, argument] : blockArgsPairs)
100- rewriter.replaceAllUsesWith (argument, value);
101- }
102-
103- if (ompOp->getRegion (0 ).getBlocks ().size () == 1 ) {
104- auto &block = *ompOp->getRegion (0 ).getBlocks ().begin ();
105- // This block is about to be removed so any arguments should have been
106- // replaced by now.
107- block.eraseArguments (0 , block.getNumArguments ());
108- if (auto terminatorOp =
109- mlir::dyn_cast<mlir::omp::TerminatorOp>(block.back ())) {
110- rewriter.eraseOp (terminatorOp);
111- }
112- rewriter.inlineBlockBefore (&block, op, {});
113- } else {
114- // When dealing with multi-block regions we need to fix up the control
115- // flow
116- auto *origBlock = ompOp->getBlock ();
117- auto *newBlock = rewriter.splitBlock (origBlock, ompOp->getIterator ());
118- auto *innerFrontBlock = &ompOp->getRegion (0 ).getBlocks ().front ();
119- builder.setInsertionPointToEnd (origBlock);
120- builder.create <mlir::cf::BranchOp>(loc, innerFrontBlock);
121- // We are no longer passing any arguments to the first block in the
122- // region, so this should be safe to erase.
123- innerFrontBlock->eraseArguments (0 , innerFrontBlock->getNumArguments ());
124-
125- for (auto &innerBlock : ompOp->getRegion (0 ).getBlocks ()) {
126- // Remove now-unused block arguments
127- for (auto arg : innerBlock.getArguments ()) {
128- if (arg.getUses ().empty ())
129- innerBlock.eraseArgument (arg.getArgNumber ());
130- }
131- if (auto terminatorOp =
132- mlir::dyn_cast<mlir::omp::TerminatorOp>(innerBlock.back ())) {
133- builder.setInsertionPointToEnd (&innerBlock);
134- builder.create <mlir::cf::BranchOp>(loc, newBlock);
135- rewriter.eraseOp (terminatorOp);
136- }
137- }
138-
139- rewriter.inlineRegionBefore (ompOp->getRegion (0 ), newBlock);
140- }
141-
142- rewriter.eraseOp (op);
143- return true ;
144- };
145-
146126 if (auto ompOp = mlir::dyn_cast<mlir::omp::LoopNestOp>(op)) {
147127 mlir::Type indexType = builder.getIndexType ();
148128 mlir::Type oldIndexType = ompOp.getIVs ().begin ()->getType ();
149129 builder.setInsertionPoint (op);
150- auto one = builder. create < mlir::arith::ConstantIndexOp>( loc, 1 );
130+ auto one = mlir::arith::ConstantIndexOp::create (builder, loc, 1 );
151131
152132 // Generate the new loop nest
153133 mlir::Block *nestBody = nullptr ;
154134 fir::DoLoopOp outerLoop = nullptr ;
155135 llvm::SmallVector<mlir::Value> loopIndArgs;
156136 for (auto extent : ompOp.getLoopUpperBounds ()) {
157137 auto ub = builder.createConvert (loc, indexType, extent);
158- auto doLoop = builder. create < fir::DoLoopOp>( loc, one, ub, one, false );
138+ auto doLoop = fir::DoLoopOp::create (builder, loc, one, ub, one, false );
159139 nestBody = doLoop.getBody ();
160140 builder.setInsertionPointToStart (nestBody);
161141 // Convert the indices to the type used inside the loop if needed
@@ -185,11 +165,12 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
185165 }
186166
187167 // Remove omp.yield at the end of the loop body
188- if (auto yieldOp = mlir::dyn_cast<mlir::omp::YieldOp>(nestBody->back ()))
168+ if (auto yieldOp =
169+ mlir::dyn_cast<mlir::omp::YieldOp>(nestBody->back ())) {
170+ assert (" omp.loop_nests's omp.yield has no operands" &&
171+ yieldOp->getNumOperands () == 0 );
189172 rewriter.eraseOp (yieldOp);
190- // DoLoopOp does not support multi-block regions, thus if we're dealing
191- // with multiple blocks we need to convert it into basic control-flow
192- // operations.
173+ }
193174 } else {
194175 rewriter.inlineRegionBefore (ompOp->getRegion (0 ), nestBody);
195176 auto indVarArg = outerLoop->getRegion (0 ).front ().getArgument (0 );
@@ -199,6 +180,9 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
199180 if (indVarArg.getType () != indexType)
200181 indVarArg.setType (indexType);
201182
183+ // fir.do_loop, unlike omp.loop_nest does not support multi-block
184+ // regions. If we're dealing with multiple blocks inside omp.loop_nest,
185+ // we need to convert it into basic control-flow operations instead.
202186 auto loopBlocks =
203187 fir::convertDoLoopToCFG (outerLoop, rewriter, false , false );
204188 auto *conditionalBlock = loopBlocks.first ;
@@ -237,7 +221,9 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
237221 if (auto yieldOp =
238222 mlir::dyn_cast<mlir::omp::YieldOp>(loopBlock->back ())) {
239223 builder.setInsertionPointToEnd (loopBlock);
240- builder.create <mlir::cf::BranchOp>(loc, lastBlock);
224+ mlir::cf::BranchOp::create (builder, loc, lastBlock);
225+ assert (" omp.loop_nests's omp.yield has no operands" &&
226+ yieldOp->getNumOperands () == 0 );
241227 rewriter.eraseOp (yieldOp);
242228 }
243229 }
@@ -255,16 +241,16 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
255241
256242 if (auto atomicReadOp = mlir::dyn_cast<mlir::omp::AtomicReadOp>(op)) {
257243 builder.setInsertionPoint (op);
258- auto loadOp = builder. create < fir::LoadOp>( loc, atomicReadOp.getX ());
259- auto storeOp = builder. create < fir::StoreOp>( loc, loadOp.getResult (),
260- atomicReadOp.getV ());
244+ auto loadOp = fir::LoadOp::create (builder, loc, atomicReadOp.getX ());
245+ auto storeOp = fir::StoreOp::create (builder, loc, loadOp.getResult (),
246+ atomicReadOp.getV ());
261247 rewriter.replaceOp (op, storeOp);
262248 return mlir::success ();
263249 }
264250
265251 if (auto atomicWriteOp = mlir::dyn_cast<mlir::omp::AtomicWriteOp>(op)) {
266- auto storeOp = builder. create < fir::StoreOp>( loc, atomicWriteOp.getExpr (),
267- atomicWriteOp.getX ());
252+ auto storeOp = fir::StoreOp::create (builder, loc, atomicWriteOp.getExpr (),
253+ atomicWriteOp.getX ());
268254 rewriter.replaceOp (op, storeOp);
269255 return mlir::success ();
270256 }
@@ -276,7 +262,7 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
276262 builder.setInsertionPointToStart (&block);
277263
278264 // Load the update `x` operand and replace its uses within the block
279- auto loadOp = builder. create < fir::LoadOp>( loc, atomicUpdateOp.getX ());
265+ auto loadOp = fir::LoadOp::create (builder, loc, atomicUpdateOp.getX ());
280266 rewriter.replaceUsesWithIf (
281267 block.getArgument (0 ), loadOp.getResult (),
282268 [&](auto &op) { return op.get ().getParentBlock () == █ });
@@ -286,14 +272,14 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
286272 auto yieldOp = mlir::cast<mlir::omp::YieldOp>(block.back ());
287273 assert (" only one yield operand" && yieldOp->getNumOperands () == 1 );
288274 builder.setInsertionPointAfter (yieldOp);
289- builder. create < fir::StoreOp>( loc, yieldOp->getOperand (0 ),
290- atomicUpdateOp.getX ());
275+ fir::StoreOp::create (builder, loc, yieldOp->getOperand (0 ),
276+ atomicUpdateOp.getX ());
291277 rewriter.eraseOp (yieldOp);
292278
293279 // Inline the final block and remove the now-empty op
294280 assert (" only one block argument" && block.getNumArguments () == 1 );
295281 block.eraseArguments (0 , block.getNumArguments ());
296- rewriter.inlineBlockBefore (&block, op , {});
282+ rewriter.inlineBlockBefore (&block, atomicUpdateOp , {});
297283 rewriter.eraseOp (op);
298284 return mlir::success ();
299285 }
@@ -305,6 +291,64 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
305291 return mlir::success ();
306292 }
307293
294+ auto inlineSimpleOp = [&](mlir::Operation *ompOp) -> bool {
295+ if (!ompOp)
296+ return false ;
297+
298+ assert (" OpenMP operation has one region" && ompOp->getNumRegions () == 1 );
299+
300+ llvm::SmallVector<std::pair<mlir::Value, mlir::BlockArgument>>
301+ blockArgsPairs;
302+ if (auto iface =
303+ mlir::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(op)) {
304+ iface.getBlockArgsPairs (blockArgsPairs);
305+ for (auto [value, argument] : blockArgsPairs)
306+ rewriter.replaceAllUsesWith (argument, value);
307+ }
308+
309+ if (ompOp->getRegion (0 ).getBlocks ().size () == 1 ) {
310+ auto &block = *ompOp->getRegion (0 ).getBlocks ().begin ();
311+ // This block is about to be removed so any arguments should have been
312+ // replaced by now.
313+ block.eraseArguments (0 , block.getNumArguments ());
314+ if (auto terminatorOp =
315+ mlir::dyn_cast<mlir::omp::TerminatorOp>(block.back ())) {
316+ rewriter.eraseOp (terminatorOp);
317+ }
318+ rewriter.inlineBlockBefore (&block, ompOp, {});
319+ } else {
320+ // When dealing with multi-block regions we need to fix up the control
321+ // flow
322+ auto *origBlock = ompOp->getBlock ();
323+ auto *newBlock = rewriter.splitBlock (origBlock, ompOp->getIterator ());
324+ auto *innerFrontBlock = &ompOp->getRegion (0 ).getBlocks ().front ();
325+ builder.setInsertionPointToEnd (origBlock);
326+ mlir::cf::BranchOp::create (builder, loc, innerFrontBlock);
327+ // We are no longer passing any arguments to the first block in the
328+ // region, so this should be safe to erase.
329+ innerFrontBlock->eraseArguments (0 , innerFrontBlock->getNumArguments ());
330+
331+ for (auto &innerBlock : ompOp->getRegion (0 ).getBlocks ()) {
332+ // Remove now-unused block arguments
333+ for (auto arg : innerBlock.getArguments ()) {
334+ if (arg.getUses ().empty ())
335+ innerBlock.eraseArgument (arg.getArgNumber ());
336+ }
337+ if (auto terminatorOp =
338+ mlir::dyn_cast<mlir::omp::TerminatorOp>(innerBlock.back ())) {
339+ builder.setInsertionPointToEnd (&innerBlock);
340+ mlir::cf::BranchOp::create (builder, loc, newBlock);
341+ rewriter.eraseOp (terminatorOp);
342+ }
343+ }
344+
345+ rewriter.inlineRegionBefore (ompOp->getRegion (0 ), newBlock);
346+ }
347+
348+ rewriter.eraseOp (op);
349+ return true ;
350+ };
351+
308352 if (inlineSimpleOp (mlir::dyn_cast<mlir::omp::TeamsOp>(op)) ||
309353 inlineSimpleOp (mlir::dyn_cast<mlir::omp::ParallelOp>(op)) ||
310354 inlineSimpleOp (mlir::dyn_cast<mlir::omp::SingleOp>(op)) ||
@@ -324,7 +368,7 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
324368 inlineSimpleOp (mlir::dyn_cast<mlir::omp::MaskedOp>(op)))
325369 return mlir::success ();
326370
327- op->emitOpError (" OpenMP operation left unhandled after SimdOnly pass." );
371+ op->emitOpError (" left unhandled after SimdOnly pass." );
328372 return mlir::failure ();
329373 }
330374};
@@ -335,10 +379,7 @@ class SimdOnlyPass : public flangomp::impl::SimdOnlyPassBase<SimdOnlyPass> {
335379 SimdOnlyPass () = default ;
336380
337381 void runOnOperation () override {
338- mlir::func::FuncOp func = getOperation ();
339-
340- if (func.isDeclaration ())
341- return ;
382+ mlir::ModuleOp module = getOperation ();
342383
343384 mlir::MLIRContext *context = &getContext ();
344385 mlir::RewritePatternSet patterns (context);
@@ -350,8 +391,8 @@ class SimdOnlyPass : public flangomp::impl::SimdOnlyPassBase<SimdOnlyPass> {
350391 mlir::GreedySimplifyRegionLevel::Disabled);
351392
352393 if (mlir::failed (
353- mlir::applyPatternsGreedily (func , std::move (patterns), config))) {
354- mlir::emitError (func .getLoc (), " error in simd-only conversion pass" );
394+ mlir::applyPatternsGreedily (module , std::move (patterns), config))) {
395+ mlir::emitError (module .getLoc (), " error in simd-only conversion pass" );
355396 signalPassFailure ();
356397 }
357398 }
0 commit comments