Skip to content

Commit 2236526

Browse files
[NFC][Flang][Pass] Add SUM intrinsic operations inside the workshare construct
1 parent 52b7141 commit 2236526

File tree

2 files changed

+117
-35
lines changed

2 files changed

+117
-35
lines changed

flang/lib/Optimizer/OpenMP/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_flang_library(FlangOpenMPTransforms
2121
FortranCommon
2222
MLIRFuncDialect
2323
MLIROpenMPDialect
24+
MLIRArithDialect
2425
HLFIRDialect
2526
MLIRIR
2627
MLIRPass

flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp

Lines changed: 116 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
//
1717
//===----------------------------------------------------------------------===//
1818

19+
#include "flang/Optimizer/Builder/HLFIRTools.h"
1920
#include <flang/Optimizer/Builder/FIRBuilder.h>
2021
#include <flang/Optimizer/Dialect/FIROps.h>
2122
#include <flang/Optimizer/Dialect/FIRType.h>
@@ -335,49 +336,129 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
335336
for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
336337
bool isLast = i + 1 == regions.size();
337338
if (std::holds_alternative<SingleRegion>(opOrSingle)) {
338-
OpBuilder singleBuilder(sourceRegion.getContext());
339-
Block *singleBlock = new Block();
340-
singleBuilder.setInsertionPointToStart(singleBlock);
341-
342339
OpBuilder allocaBuilder(sourceRegion.getContext());
343340
Block *allocaBlock = new Block();
344341
allocaBuilder.setInsertionPointToStart(allocaBlock);
345342

346-
OpBuilder parallelBuilder(sourceRegion.getContext());
347-
Block *parallelBlock = new Block();
348-
parallelBuilder.setInsertionPointToStart(parallelBlock);
349-
350-
auto [allParallelized, copyprivateVars] =
351-
moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
352-
singleBuilder, parallelBuilder);
353-
if (allParallelized) {
354-
// The single region was not required as all operations were safe to
355-
// parallelize
356-
assert(copyprivateVars.empty());
357-
assert(allocaBlock->empty());
358-
delete singleBlock;
343+
it = block.begin();
344+
while (&*it != terminator)
345+
if (isa<hlfir::SumOp>(it))
346+
break;
347+
else
348+
it++;
349+
350+
if (auto sumOp = dyn_cast<hlfir::SumOp>(it)) {
351+
/// Implementation:
352+
/// Lowering of the intrinsic function `SUM`:
353+
/// --
354+
/// x = sum(array)
355+
///
356+
/// is converted to
357+
///
358+
/// !$omp parallel do
359+
/// do i = 1, size(array)
360+
/// x = x + array(i)
361+
/// end do
362+
/// !$omp end parallel do
363+
364+
OpBuilder wslBuilder(sourceRegion.getContext());
365+
Block *wslBlock = new Block();
366+
wslBuilder.setInsertionPointToStart(wslBlock);
367+
368+
Value target = dyn_cast<hlfir::AssignOp>(++it).getLhs();
369+
Value array = sumOp.getArray();
370+
Value dim = sumOp.getDim();
371+
fir::SequenceType arrayTy = dyn_cast<fir::SequenceType>(
372+
hlfir::getFortranElementOrSequenceType(array.getType()));
373+
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
374+
if (arrayShape.size() == 1 && !dim) {
375+
Value itr = allocaBuilder.create<fir::AllocaOp>(
376+
loc, allocaBuilder.getI64Type());
377+
Value c_one = allocaBuilder.create<arith::ConstantOp>(
378+
loc, allocaBuilder.getI64IntegerAttr(1));
379+
Value c_arr_size = allocaBuilder.create<arith::ConstantOp>(
380+
loc, allocaBuilder.getI64IntegerAttr(arrayShape[0]));
381+
// Value c_zero = allocaBuilder.create<arith::ConstantOp>(loc,
382+
// allocaBuilder.getZeroAttr(arrayTy.getEleTy()));
383+
// allocaBuilder.create<fir::StoreOp>(loc, c_zero, target);
384+
385+
omp::WsloopOperands wslOps;
386+
omp::WsloopOp wslOp =
387+
rootBuilder.create<omp::WsloopOp>(loc, wslOps);
388+
389+
hlfir::LoopNest ln;
390+
ln.outerOp = wslOp;
391+
omp::LoopNestOperands lnOps;
392+
lnOps.loopLowerBounds.push_back(c_one);
393+
lnOps.loopUpperBounds.push_back(c_arr_size);
394+
lnOps.loopSteps.push_back(c_one);
395+
lnOps.loopInclusive = wslBuilder.getUnitAttr();
396+
omp::LoopNestOp lnOp =
397+
wslBuilder.create<omp::LoopNestOp>(loc, lnOps);
398+
Block *lnBlock = wslBuilder.createBlock(&lnOp.getRegion());
399+
lnBlock->addArgument(c_one.getType(), loc);
400+
wslBuilder.create<fir::StoreOp>(
401+
loc, lnOp.getRegion().getArgument(0), itr);
402+
Value tarLoad = wslBuilder.create<fir::LoadOp>(loc, target);
403+
Value itrLoad = wslBuilder.create<fir::LoadOp>(loc, itr);
404+
hlfir::DesignateOp arrDesOp = wslBuilder.create<hlfir::DesignateOp>(
405+
loc, fir::ReferenceType::get(arrayTy.getEleTy()), array,
406+
itrLoad);
407+
Value desLoad = wslBuilder.create<fir::LoadOp>(loc, arrDesOp);
408+
Value addf =
409+
wslBuilder.create<arith::AddFOp>(loc, tarLoad, desLoad);
410+
wslBuilder.create<fir::StoreOp>(loc, addf, target);
411+
wslBuilder.create<omp::YieldOp>(loc);
412+
ln.body = lnBlock;
413+
wslOp.getRegion().push_back(wslBlock);
414+
targetRegion.front().getOperations().splice(
415+
wslOp->getIterator(), allocaBlock->getOperations());
416+
} else {
417+
emitError(loc, "Only 1D array scalar assignment for sum "
418+
"instrinsic is supported in workshare construct");
419+
return;
420+
}
359421
} else {
360-
omp::SingleOperands singleOperands;
361-
if (isLast)
362-
singleOperands.nowait = rootBuilder.getUnitAttr();
363-
singleOperands.copyprivateVars = copyprivateVars;
364-
cleanupBlock(singleBlock);
365-
for (auto var : singleOperands.copyprivateVars) {
366-
mlir::func::FuncOp funcOp =
367-
createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
368-
singleOperands.copyprivateSyms.push_back(
369-
SymbolRefAttr::get(funcOp));
422+
OpBuilder singleBuilder(sourceRegion.getContext());
423+
Block *singleBlock = new Block();
424+
singleBuilder.setInsertionPointToStart(singleBlock);
425+
426+
OpBuilder parallelBuilder(sourceRegion.getContext());
427+
Block *parallelBlock = new Block();
428+
parallelBuilder.setInsertionPointToStart(parallelBlock);
429+
430+
auto [allParallelized, copyprivateVars] =
431+
moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
432+
singleBuilder, parallelBuilder);
433+
if (allParallelized) {
434+
// The single region was not required as all operations were safe to
435+
// parallelize
436+
assert(copyprivateVars.empty());
437+
assert(allocaBlock->empty());
438+
delete singleBlock;
439+
} else {
440+
omp::SingleOperands singleOperands;
441+
if (isLast)
442+
singleOperands.nowait = rootBuilder.getUnitAttr();
443+
singleOperands.copyprivateVars = copyprivateVars;
444+
cleanupBlock(singleBlock);
445+
for (auto var : singleOperands.copyprivateVars) {
446+
mlir::func::FuncOp funcOp =
447+
createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
448+
singleOperands.copyprivateSyms.push_back(
449+
SymbolRefAttr::get(funcOp));
450+
}
451+
omp::SingleOp singleOp =
452+
rootBuilder.create<omp::SingleOp>(loc, singleOperands);
453+
singleOp.getRegion().push_back(singleBlock);
454+
targetRegion.front().getOperations().splice(
455+
singleOp->getIterator(), allocaBlock->getOperations());
370456
}
371-
omp::SingleOp singleOp =
372-
rootBuilder.create<omp::SingleOp>(loc, singleOperands);
373-
singleOp.getRegion().push_back(singleBlock);
374-
targetRegion.front().getOperations().splice(
375-
singleOp->getIterator(), allocaBlock->getOperations());
457+
rootBuilder.getInsertionBlock()->getOperations().splice(
458+
rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
459+
delete parallelBlock;
376460
}
377-
rootBuilder.getInsertionBlock()->getOperations().splice(
378-
rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
379461
delete allocaBlock;
380-
delete parallelBlock;
381462
} else {
382463
auto op = std::get<Operation *>(opOrSingle);
383464
if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {

0 commit comments

Comments
 (0)