Skip to content

Commit 66ee96f

Browse files
authored
Move reduction lowering from DistributeOp to TeamsOp and use teams reduction clauses to generate info. (#159)
* Move teams reductions from distribute to teams, and use the reduction clause of the teams directive to create the reduction information.
* Remove the composite matching framework, which is no longer needed with the new teams reduction implementation.
1 parent 71a0d75 commit 66ee96f

File tree

14 files changed

+212
-350
lines changed

14 files changed

+212
-350
lines changed

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,7 +1660,6 @@ void CGOpenMPRuntimeGPU::emitReduction(
16601660
return;
16611661

16621662
bool ParallelReduction = isOpenMPParallelDirective(Options.ReductionKind);
1663-
bool DistributeReduction = isOpenMPDistributeDirective(Options.ReductionKind);
16641663
bool TeamsReduction = isOpenMPTeamsDirective(Options.ReductionKind);
16651664

16661665
ASTContext &C = CGM.getContext();
@@ -1756,7 +1755,7 @@ void CGOpenMPRuntimeGPU::emitReduction(
17561755

17571756
CGF.Builder.restoreIP(OMPBuilder.createReductionsGPU(
17581757
OmpLoc, AllocaIP, CodeGenIP, ReductionInfos, false, TeamsReduction,
1759-
DistributeReduction, llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
1758+
llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
17601759
CGF.getTarget().getGridValue(), C.getLangOpts().OpenMPCUDAReductionBufNum,
17611760
RTLoc));
17621761
return;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,6 @@ genReductionVars(mlir::Operation *op, lower::AbstractConverter &converter,
796796

797797
mlir::Block *entryBlock = firOpBuilder.createBlock(
798798
&op->getRegion(0), {}, reductionTypes, blockArgLocs);
799-
800799
// Bind the reduction arguments to their block arguments.
801800
for (auto [arg, prv] :
802801
llvm::zip_equal(reductionArgs, entryBlock->getArguments())) {
@@ -1659,14 +1658,15 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter,
16591658
loc, llvm::omp::Directive::OMPD_taskwait);
16601659
}
16611660

1662-
static void
1663-
genTeamsClauses(lower::AbstractConverter &converter,
1664-
semantics::SemanticsContext &semaCtx,
1665-
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
1666-
mlir::Location loc, bool evalOutsideTarget,
1667-
mlir::omp::TeamsOperands &clauseOps,
1668-
mlir::omp::NumTeamsClauseOps &numTeamsClauseOps,
1669-
mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps) {
1661+
static void genTeamsClauses(
1662+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1663+
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
1664+
mlir::Location loc, bool evalOutsideTarget,
1665+
mlir::omp::TeamsOperands &clauseOps,
1666+
mlir::omp::NumTeamsClauseOps &numTeamsClauseOps,
1667+
mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps,
1668+
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
1669+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
16701670
ClauseProcessor cp(converter, semaCtx, clauses);
16711671
cp.processAllocate(clauseOps);
16721672
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
@@ -1684,8 +1684,7 @@ genTeamsClauses(lower::AbstractConverter &converter,
16841684
cp.processNumTeams(stmtCtx, numTeamsClauseOps);
16851685
cp.processThreadLimit(stmtCtx, threadLimitClauseOps);
16861686
}
1687-
1688-
// cp.processTODO<clause::Reduction>(loc, llvm::omp::Directive::OMPD_teams);
1687+
cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
16891688
}
16901689

16911690
static void genWsloopClauses(
@@ -1874,7 +1873,6 @@ static mlir::omp::ParallelOp genParallelOp(
18741873
llvm::ArrayRef<mlir::Type> reductionTypes, DataSharingProcessor *dsp,
18751874
bool isComposite = false, mlir::omp::TargetOp parentTarget = nullptr) {
18761875
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1877-
18781876
auto reductionCallback = [&](mlir::Operation *op) {
18791877
genReductionVars(op, converter, loc, reductionSyms, reductionTypes);
18801878
return llvm::SmallVector<const semantics::Symbol *>(reductionSyms);
@@ -2360,14 +2358,22 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23602358
mlir::omp::TeamsOperands clauseOps;
23612359
mlir::omp::NumTeamsClauseOps numTeamsClauseOps;
23622360
mlir::omp::ThreadLimitClauseOps threadLimitClauseOps;
2361+
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
2362+
llvm::SmallVector<mlir::Type> reductionTypes;
23632363
genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
23642364
evalOutsideTarget, clauseOps, numTeamsClauseOps,
2365-
threadLimitClauseOps);
2365+
threadLimitClauseOps, reductionTypes, reductionSyms);
2366+
2367+
auto reductionCallback = [&](mlir::Operation *op) {
2368+
genReductionVars(op, converter, loc, reductionSyms, reductionTypes);
2369+
return llvm::SmallVector<const semantics::Symbol *>(reductionSyms);
2370+
};
23662371

23672372
auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
23682373
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
23692374
llvm::omp::Directive::OMPD_teams)
2370-
.setClauses(&item->clauses),
2375+
.setClauses(&item->clauses)
2376+
.setGenRegionEntryCb(reductionCallback),
23712377
queue, item, clauseOps);
23722378

23732379
if (numTeamsClauseOps.numTeamsUpper) {
@@ -2436,7 +2442,6 @@ static void genStandaloneDo(lower::AbstractConverter &converter,
24362442
const ConstructQueue &queue,
24372443
ConstructQueue::const_iterator item) {
24382444
lower::StatementContext stmtCtx;
2439-
24402445
mlir::omp::WsloopOperands wsloopClauseOps;
24412446
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
24422447
llvm::SmallVector<mlir::Type> reductionTypes;
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s | FileCheck %s
2+
! RUN: bbc -emit-fir -fopenmp -o - %s | FileCheck %s
3+
4+
! CHECK: omp.teams
5+
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}} -> %{{.*}} : !fir.ref<i32>)
6+
subroutine myfun()
7+
integer :: i, j
8+
i = 0
9+
j = 0
10+
!$omp target teams distribute parallel do reduction(+:i)
11+
do j = 1,5
12+
i = i + j
13+
end do
14+
!$omp end target teams distribute parallel do
15+
end subroutine myfun

flang/test/Lower/OpenMP/Todo/reduction-teams.f90 renamed to flang/test/Lower/OpenMP/reduction-teams.f90

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
! RUN: bbc -emit-fir -fopenmp -o - %s | FileCheck %s
22
! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s | FileCheck %s
3-
! XFAIL: *
43

54
! CHECK: omp.teams
65
! CHECK-SAME: reduction

flang/test/Lower/OpenMP/sections-array-reduction.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ subroutine sectionsReduction(x)
3535
! CHECK: omp.parallel {
3636
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
3737
! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
38-
! CHECK: omp.sections reduction(byref @add_reduction_byref_box_Uxf32 -> %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
38+
! CHECK: omp.sections reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
3939
! CHECK: ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>):
4040
! CHECK: omp.section {
4141
! CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>):

flang/test/Lower/OpenMP/sections-reduction.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ subroutine sectionsReduction(x,y)
4040
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref<f32>, !fir.dscope) -> (!fir.ref<f32>, !fir.ref<f32>)
4141
! CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_1]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEy"} : (!fir.ref<f32>, !fir.dscope) -> (!fir.ref<f32>, !fir.ref<f32>)
4242
! CHECK: omp.parallel {
43-
! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref<f32>, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref<f32>) {
43+
! CHECK: omp.sections reduction(@add_reduction_f32 %[[VAL_3]]#0 -> %[[ARG_0:.*]] : !fir.ref<f32>, @add_reduction_f32 %[[VAL_4]]#0 -> %[[ARG_1:.*]] : !fir.ref<f32>) {
4444
! CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref<f32>, %[[VAL_6:.*]]: !fir.ref<f32>):
4545
! CHECK: omp.section {
4646
! CHECK: ^bb0(%[[VAL_7:.*]]: !fir.ref<f32>, %[[VAL_8:.*]]: !fir.ref<f32>):
@@ -71,7 +71,7 @@ subroutine sectionsReduction(x,y)
7171
! CHECK: omp.terminator
7272
! CHECK: }
7373
! CHECK: omp.parallel {
74-
! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref<f32>, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref<f32>) {
74+
! CHECK: omp.sections reduction(@add_reduction_f32 %[[VAL_3]]#0 -> %[[ARG_2:.*]] : !fir.ref<f32>, @add_reduction_f32 %[[VAL_4]]#0 -> %[[ARG_3:.*]] : !fir.ref<f32>) {
7575
! CHECK: ^bb0(%[[VAL_23:.*]]: !fir.ref<f32>, %[[VAL_24:.*]]: !fir.ref<f32>):
7676
! CHECK: omp.section {
7777
! CHECK: ^bb0(%[[VAL_25:.*]]: !fir.ref<f32>, %[[VAL_26:.*]]: !fir.ref<f32>):

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,8 +1844,6 @@ class OpenMPIRBuilder {
18441844
/// nowait.
18451845
/// \param IsTeamsReduction Optional flag set if it is a teams
18461846
/// reduction.
1847-
/// \param HasDistribute Optional flag set if it is a
1848-
/// distribute reduction.
18491847
/// \param GridValue Optional GPU grid value.
18501848
/// \param ReductionBufNum Optional OpenMPCUDAReductionBufNumValue to be
18511849
/// used for teams reduction.
@@ -1854,7 +1852,6 @@ class OpenMPIRBuilder {
18541852
const LocationDescription &Loc, InsertPointTy AllocaIP,
18551853
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
18561854
bool IsNoWait = false, bool IsTeamsReduction = false,
1857-
bool HasDistribute = false,
18581855
ReductionGenCBKind ReductionGenCBKind = ReductionGenCBKind::MLIR,
18591856
std::optional<omp::GV> GridValue = {}, unsigned ReductionBufNum = 1024,
18601857
Value *SrcLocInfo = nullptr);
@@ -1926,8 +1923,7 @@ class OpenMPIRBuilder {
19261923
InsertPointTy AllocaIP,
19271924
ArrayRef<ReductionInfo> ReductionInfos,
19281925
ArrayRef<bool> IsByRef, bool IsNoWait = false,
1929-
bool IsTeamsReduction = false,
1930-
bool HasDistribute = false);
1926+
bool IsTeamsReduction = false);
19311927

19321928
///}
19331929

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3412,9 +3412,9 @@ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
34123412
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
34133413
const LocationDescription &Loc, InsertPointTy AllocaIP,
34143414
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3415-
bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
3416-
ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
3417-
unsigned ReductionBufNum, Value *SrcLocInfo) {
3415+
bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind,
3416+
std::optional<omp::GV> GridValue, unsigned ReductionBufNum,
3417+
Value *SrcLocInfo) {
34183418
if (!updateToLocation(Loc))
34193419
return InsertPointTy();
34203420
Builder.restoreIP(CodeGenIP);
@@ -3590,13 +3590,11 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
35903590
ReductionFunc;
35913591
});
35923592
} else {
3593-
if (!HasDistribute || IsTeamsReduction) {
3594-
Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
3595-
Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
3596-
Value *Reduced;
3597-
RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
3598-
Builder.CreateStore(Reduced, LHS, false);
3599-
}
3593+
Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
3594+
Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
3595+
Value *Reduced;
3596+
RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
3597+
Builder.CreateStore(Reduced, LHS, false);
36003598
}
36013599
}
36023600
emitBlock(ExitBB, CurFunc);
@@ -3685,11 +3683,11 @@ static void populateReductionFunction(
36853683
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
36863684
const LocationDescription &Loc, InsertPointTy AllocaIP,
36873685
ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
3688-
bool IsNoWait, bool IsTeamsReduction, bool HasDistribute) {
3686+
bool IsNoWait, bool IsTeamsReduction) {
36893687
assert(ReductionInfos.size() == IsByRef.size());
36903688
if (Config.isGPU())
36913689
return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos,
3692-
IsNoWait, IsTeamsReduction, HasDistribute);
3690+
IsNoWait, IsTeamsReduction);
36933691

36943692
checkReductionInfos(ReductionInfos, /*IsGPU*/ false);
36953693

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -472,16 +472,20 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472472
//===----------------------------------------------------------------------===//
473473

474474
static ParseResult parseClauseWithRegionArgs(
475-
OpAsmParser &parser, Region &region,
475+
OpAsmParser &parser,
476476
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
477477
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols,
478-
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
478+
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
479+
bool parseParens = true) {
479480
SmallVector<SymbolRefAttr> reductionVec;
480481
SmallVector<bool> isByRefVec;
481482
unsigned regionArgOffset = regionPrivateArgs.size();
482483

484+
OpAsmParser::Delimiter delimiter = parseParens ? OpAsmParser::Delimiter::Paren
485+
: OpAsmParser::Delimiter::None;
486+
483487
if (failed(
484-
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
488+
parser.parseCommaSeparatedList(delimiter, [&]() {
485489
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
486490
if (parser.parseAttribute(reductionVec.emplace_back()) ||
487491
parser.parseOperand(operands.emplace_back()) ||
@@ -536,15 +540,15 @@ static ParseResult parseParallelRegion(
536540
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
537541

538542
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
539-
if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
543+
if (failed(parseClauseWithRegionArgs(parser, reductionVars,
540544
reductionTypes, reductionByref,
541545
reductionSyms, regionPrivateArgs)))
542546
return failure();
543547
}
544548

545549
if (succeeded(parser.parseOptionalKeyword("private"))) {
546550
auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
547-
if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
551+
if (failed(parseClauseWithRegionArgs(parser, privateVars,
548552
privateTypes, privateByref,
549553
privateSyms, regionPrivateArgs)))
550554
return failure();
@@ -597,48 +601,26 @@ static ParseResult parseReductionVarList(
597601
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
598602
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
599603
ArrayAttr &reductionSyms) {
600-
SmallVector<SymbolRefAttr> reductionVec;
601-
SmallVector<bool> isByRefVec;
602-
if (failed(parser.parseCommaSeparatedList([&]() {
603-
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
604-
if (parser.parseAttribute(reductionVec.emplace_back()) ||
605-
parser.parseArrow() ||
606-
parser.parseOperand(reductionVars.emplace_back()) ||
607-
parser.parseColonType(reductionTypes.emplace_back()))
608-
return failure();
609-
isByRefVec.push_back(optionalByref.succeeded());
610-
return success();
611-
})))
612-
return failure();
613-
reductionByref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
614-
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
615-
reductionSyms = ArrayAttr::get(parser.getContext(), reductions);
616-
return success();
604+
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
605+
return parseClauseWithRegionArgs(parser, reductionVars, reductionTypes,
606+
reductionByref, reductionSyms,
607+
regionPrivateArgs, /*parseParens=*/false);
617608
}
618609

619610
/// Print Reduction clause
620-
static void
621-
printReductionVarList(OpAsmPrinter &p, Operation *op,
622-
OperandRange reductionVars, TypeRange reductionTypes,
623-
std::optional<DenseBoolArrayAttr> reductionByref,
624-
std::optional<ArrayAttr> reductionSyms) {
625-
auto getByRef = [&](unsigned i) -> const char * {
626-
if (!reductionByref || !*reductionByref)
627-
return "";
628-
assert(reductionByref->empty() || i < reductionByref->size());
629-
if (!reductionByref->empty() && (*reductionByref)[i])
630-
return "byref ";
631-
return "";
632-
};
633-
634-
for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
635-
if (i != 0)
636-
p << ", ";
637-
p << getByRef(i) << (*reductionSyms)[i] << " -> " << reductionVars[i]
638-
<< " : " << reductionVars[i].getType();
611+
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
612+
OperandRange reductionVars,
613+
TypeRange reductionTypes,
614+
DenseBoolArrayAttr reductionByref,
615+
ArrayAttr reductionSyms) {
616+
if (reductionSyms) {
617+
auto *argsBegin = op->getRegion(0).front().getArguments().begin();
618+
MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
619+
printClauseWithRegionArgs(p, op, argsSubrange, llvm::StringRef(),
620+
reductionVars, reductionTypes, reductionByref,
621+
reductionSyms);
639622
}
640623
}
641-
642624
/// Verifies Reduction Clause
643625
static LogicalResult
644626
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
@@ -1824,7 +1806,7 @@ parseWsloop(OpAsmParser &parser, Region &region,
18241806
// Parse an optional reduction clause
18251807
llvm::SmallVector<OpAsmParser::Argument> privates;
18261808
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
1827-
if (failed(parseClauseWithRegionArgs(parser, region, reductionOperands,
1809+
if (failed(parseClauseWithRegionArgs(parser, reductionOperands,
18281810
reductionTypes, reductionByRef,
18291811
reductionSymbols, privates)))
18301812
return failure();

0 commit comments

Comments
 (0)