Skip to content

Commit 6c8216e

Browse files
committed
rebase
Created using spr 1.3.7
2 parents f391970 + 6ca0a10 commit 6c8216e

File tree

30 files changed

+861
-146
lines changed

30 files changed

+861
-146
lines changed

clang/lib/Tooling/Transformer/RangeSelector.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ RangeSelector transformer::node(std::string ID) {
139139
(Node->get<Stmt>() != nullptr && Node->get<Expr>() == nullptr))
140140
? tooling::getExtendedRange(*Node, tok::TokenKind::semi,
141141
*Result.Context)
142-
: CharSourceRange::getTokenRange(Node->getSourceRange());
142+
: CharSourceRange::getTokenRange(
143+
Node->getSourceRange(/*IncludeQualifier=*/true));
143144
};
144145
}
145146

clang/unittests/Tooling/RangeSelectorTest.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,13 @@ TEST(RangeSelectorTest, NodeOpExpression) {
339339
EXPECT_THAT_EXPECTED(select(node("id"), Match), HasValue("3"));
340340
}
341341

342+
TEST(RangeSelectorTest, NodeOpTypeLoc) {
343+
StringRef Code = "namespace ns {struct Foo{};} ns::Foo a;";
344+
TestMatch Match =
345+
matchCode(Code, varDecl(hasTypeLoc(typeLoc().bind("typeloc"))));
346+
EXPECT_THAT_EXPECTED(select(node("typeloc"), Match), HasValue("ns::Foo"));
347+
}
348+
342349
TEST(RangeSelectorTest, StatementOp) {
343350
StringRef Code = "int f() { return 3; }";
344351
TestMatch Match = matchCode(Code, expr().bind("id"));

flang/include/flang/Lower/Support/ReductionProcessor.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ namespace omp {
4040

4141
class ReductionProcessor {
4242
public:
43+
using GenInitValueCBTy =
44+
std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
45+
mlir::Type type, mlir::Value ompOrig)>;
46+
using GenCombinerCBTy = std::function<void(
47+
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
48+
mlir::Value op1, mlir::Value op2, bool isByRef)>;
49+
4350
// TODO: Move this enumeration to the OpenMP dialect
4451
enum ReductionIdentifier {
4552
ID,
@@ -58,6 +65,9 @@ class ReductionProcessor {
5865
IEOR
5966
};
6067

68+
static bool doReductionByRef(mlir::Type reductionType);
69+
static bool doReductionByRef(mlir::Value reductionVar);
70+
6171
static ReductionIdentifier
6272
getReductionType(const omp::clause::ProcedureDesignator &pd);
6373

@@ -109,6 +119,14 @@ class ReductionProcessor {
109119
ReductionIdentifier redId,
110120
mlir::Type type, mlir::Value op1,
111121
mlir::Value op2);
122+
/// Creates an OpenMP reduction declaration and inserts it into the provided
123+
/// symbol table. The init and combiner regions are generated by the callback
124+
/// functions genCombinerCB and genInitValueCB.
125+
template <typename DeclareRedType>
126+
static DeclareRedType createDeclareReductionHelper(
127+
AbstractConverter &converter, llvm::StringRef reductionOpName,
128+
mlir::Type type, mlir::Location loc, bool isByRef,
129+
GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);
112130

113131
/// Creates an OpenMP reduction declaration and inserts it into the provided
114132
/// symbol table. The declaration has a constant initializer with the neutral

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "ClauseProcessor.h"
1414
#include "Utils.h"
1515

16+
#include "flang/Lower/ConvertCall.h"
1617
#include "flang/Lower/ConvertExprToHLFIR.h"
1718
#include "flang/Lower/OpenMP/Clauses.h"
1819
#include "flang/Lower/PFTBuilder.h"
@@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive(
402403
return false;
403404
}
404405

406+
bool ClauseProcessor::processInitializer(
407+
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
408+
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
409+
if (auto *clause = findUniqueClause<omp::clause::Initializer>()) {
410+
genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
411+
mlir::Type type, mlir::Value ompOrig) {
412+
lower::SymMapScope scope(symMap);
413+
const parser::OmpInitializerExpression &iexpr = inp.v.v;
414+
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
415+
const std::list<parser::OmpStylizedDeclaration> &declList =
416+
std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
417+
mlir::Value ompPrivVar;
418+
for (const parser::OmpStylizedDeclaration &decl : declList) {
419+
auto &name = std::get<parser::ObjectName>(decl.var.t);
420+
assert(name.symbol && "Name does not have a symbol");
421+
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
422+
fir::StoreOp::create(builder, loc, ompOrig, addr);
423+
fir::FortranVariableFlagsEnum extraFlags = {};
424+
fir::FortranVariableFlagsAttr attributes =
425+
Fortran::lower::translateSymbolAttributes(builder.getContext(),
426+
*name.symbol, extraFlags);
427+
auto declareOp = hlfir::DeclareOp::create(
428+
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
429+
0, attributes);
430+
if (name.ToString() == "omp_priv")
431+
ompPrivVar = declareOp.getResult(0);
432+
symMap.addVariableDefinition(*name.symbol, declareOp);
433+
}
434+
// Lower the expression/function call
435+
lower::StatementContext stmtCtx;
436+
mlir::Value result = common::visit(
437+
common::visitors{
438+
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
439+
convertCallToHLFIR(loc, converter, procRef, std::nullopt,
440+
symMap, stmtCtx);
441+
auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar);
442+
return privVal;
443+
},
444+
[&](const auto &expr) -> mlir::Value {
445+
mlir::Value exprResult = fir::getBase(convertExprToValue(
446+
loc, converter, clause->v, symMap, stmtCtx));
447+
// Conversion can either give a value or a refrence to a value,
448+
// we need to return the reduction type, so an optional load may
449+
// be generated.
450+
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
451+
exprResult.getType()))
452+
if (ompPrivVar.getType() == refType)
453+
exprResult = fir::LoadOp::create(builder, loc, exprResult);
454+
return exprResult;
455+
}},
456+
clause->v.u);
457+
stmtCtx.finalizeAndPop();
458+
return result;
459+
};
460+
return true;
461+
}
462+
return false;
463+
}
464+
405465
bool ClauseProcessor::processMergeable(
406466
mlir::omp::MergeableClauseOps &result) const {
407467
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "flang/Lower/Bridge.h"
1919
#include "flang/Lower/DirectivesCommon.h"
2020
#include "flang/Lower/OpenMP/Clauses.h"
21+
#include "flang/Lower/Support/ReductionProcessor.h"
2122
#include "flang/Optimizer/Builder/Todo.h"
2223
#include "flang/Parser/dump-parse-tree.h"
2324
#include "flang/Parser/parse-tree.h"
@@ -88,6 +89,9 @@ class ClauseProcessor {
8889
bool processHint(mlir::omp::HintClauseOps &result) const;
8990
bool processInclusive(mlir::Location currentLocation,
9091
mlir::omp::InclusiveClauseOps &result) const;
92+
bool processInitializer(
93+
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
94+
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
9195
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
9296
bool processNogroup(mlir::omp::NogroupClauseOps &result) const;
9397
bool processNowait(mlir::omp::NowaitClauseOps &result) const;

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp,
981981

982982
Initializer make(const parser::OmpClause::Initializer &inp,
983983
semantics::SemanticsContext &semaCtx) {
984-
llvm_unreachable("Empty: initializer");
984+
const parser::OmpInitializerExpression &iexpr = inp.v.v;
985+
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
986+
const parser::OmpStylizedInstance::Instance &instance =
987+
std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);
988+
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
989+
auto &expr = std::get<parser::Expr>(as->t);
990+
return Initializer{makeExpr(expr, semaCtx)};
991+
} else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
992+
if (call->typedCall) {
993+
const auto &procRef = *call->typedCall;
994+
semantics::SomeExpr evalProcRef{procRef};
995+
return Initializer{evalProcRef};
996+
}
997+
}
998+
999+
llvm_unreachable("Unexpected initializer");
9851000
}
9861001

9871002
InReduction make(const parser::OmpClause::InReduction &inp,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 149 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818
#include "Decomposer.h"
1919
#include "Utils.h"
2020
#include "flang/Common/idioms.h"
21+
#include "flang/Evaluate/type.h"
2122
#include "flang/Lower/Bridge.h"
2223
#include "flang/Lower/ConvertExpr.h"
24+
#include "flang/Lower/ConvertExprToHLFIR.h"
2325
#include "flang/Lower/ConvertVariable.h"
2426
#include "flang/Lower/DirectivesCommon.h"
2527
#include "flang/Lower/OpenMP/Clauses.h"
2628
#include "flang/Lower/StatementContext.h"
29+
#include "flang/Lower/Support/ReductionProcessor.h"
2730
#include "flang/Lower/SymbolMap.h"
2831
#include "flang/Optimizer/Builder/BoxValue.h"
2932
#include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
28472850
// TODO: Add private syms and vars.
28482851
args.reduction.syms = reductionSyms;
28492852
args.reduction.vars = clauseOps.reductionVars;
2850-
28512853
return genOpWithBody<mlir::omp::TeamsOp>(
28522854
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
28532855
llvm::omp::Directive::OMPD_teams)
@@ -3570,12 +3572,156 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
35703572
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
35713573
}
35723574

3575+
static ReductionProcessor::GenCombinerCBTy
3576+
processReductionCombiner(lower::AbstractConverter &converter,
3577+
lower::SymMap &symTable,
3578+
semantics::SemanticsContext &semaCtx,
3579+
const parser::OmpReductionSpecifier &specifier) {
3580+
ReductionProcessor::GenCombinerCBTy genCombinerCB;
3581+
const auto &combinerExpression =
3582+
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
3583+
.value();
3584+
const parser::OmpStylizedInstance &combinerInstance =
3585+
combinerExpression.v.front();
3586+
const parser::OmpStylizedInstance::Instance &instance =
3587+
std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
3588+
3589+
const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
3590+
if (!as) {
3591+
TODO(converter.getCurrentLocation(),
3592+
"A combiner that is a subroutine call is not yet supported");
3593+
}
3594+
auto &expr = std::get<parser::Expr>(as->t);
3595+
genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
3596+
mlir::Type type, mlir::Value lhs, mlir::Value rhs,
3597+
bool isByRef) {
3598+
const auto &evalExpr = makeExpr(expr, semaCtx);
3599+
lower::SymMapScope scope(symTable);
3600+
const std::list<parser::OmpStylizedDeclaration> &declList =
3601+
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
3602+
for (const parser::OmpStylizedDeclaration &decl : declList) {
3603+
auto &name = std::get<parser::ObjectName>(decl.var.t);
3604+
mlir::Value addr = lhs;
3605+
mlir::Type type = lhs.getType();
3606+
bool isRhs = name.ToString() == std::string("omp_in");
3607+
if (isRhs) {
3608+
addr = rhs;
3609+
type = rhs.getType();
3610+
}
3611+
3612+
assert(name.symbol && "Reduction object name does not have a symbol");
3613+
if (!fir::conformsWithPassByRef(type)) {
3614+
addr = builder.createTemporary(loc, type);
3615+
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
3616+
}
3617+
fir::FortranVariableFlagsEnum extraFlags = {};
3618+
fir::FortranVariableFlagsAttr attributes =
3619+
Fortran::lower::translateSymbolAttributes(builder.getContext(),
3620+
*name.symbol, extraFlags);
3621+
auto declareOp =
3622+
hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
3623+
{}, nullptr, nullptr, 0, attributes);
3624+
symTable.addVariableDefinition(*name.symbol, declareOp);
3625+
}
3626+
3627+
lower::StatementContext stmtCtx;
3628+
mlir::Value result = fir::getBase(
3629+
convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
3630+
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
3631+
if (lhs.getType() == refType.getElementType())
3632+
result = fir::LoadOp::create(builder, loc, result);
3633+
stmtCtx.finalizeAndPop();
3634+
if (isByRef) {
3635+
fir::StoreOp::create(builder, loc, result, lhs);
3636+
mlir::omp::YieldOp::create(builder, loc, lhs);
3637+
} else {
3638+
mlir::omp::YieldOp::create(builder, loc, result);
3639+
}
3640+
};
3641+
return genCombinerCB;
3642+
}
3643+
3644+
// Checks that the reduction type is either a trivial type or a derived type of
3645+
// trivial types.
3646+
static bool isSimpleReductionType(mlir::Type reductionType) {
3647+
if (fir::isa_trivial(reductionType))
3648+
return true;
3649+
if (auto recordTy = mlir::dyn_cast<fir::RecordType>(reductionType)) {
3650+
for (auto [_, fieldType] : recordTy.getTypeList()) {
3651+
if (!fir::isa_trivial(fieldType))
3652+
return false;
3653+
}
3654+
}
3655+
return true;
3656+
}
3657+
3658+
// Getting the type from a symbol compared to a DeclSpec is simpler since we do
3659+
// not need to consider derived vs intrinsic types. Semantics is guaranteed to
3660+
// generate these symbols.
3661+
static mlir::Type
3662+
getReductionType(lower::AbstractConverter &converter,
3663+
const parser::OmpReductionSpecifier &specifier) {
3664+
const auto &combinerExpression =
3665+
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
3666+
.value();
3667+
const parser::OmpStylizedInstance &combinerInstance =
3668+
combinerExpression.v.front();
3669+
const std::list<parser::OmpStylizedDeclaration> &declList =
3670+
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
3671+
const parser::OmpStylizedDeclaration &decl = declList.front();
3672+
const auto &name = std::get<parser::ObjectName>(decl.var.t);
3673+
const auto &symbol = semantics::SymbolRef(*name.symbol);
3674+
mlir::Type reductionType = converter.genType(symbol);
3675+
3676+
if (!isSimpleReductionType(reductionType))
3677+
TODO(converter.getCurrentLocation(),
3678+
"declare reduction currently only supports trival types or derived "
3679+
"types containing trivial types");
3680+
return reductionType;
3681+
}
3682+
35733683
static void genOMP(
35743684
lower::AbstractConverter &converter, lower::SymMap &symTable,
35753685
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
35763686
const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
3577-
if (!semaCtx.langOptions().OpenMPSimd)
3578-
TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
3687+
if (semaCtx.langOptions().OpenMPSimd)
3688+
return;
3689+
3690+
const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
3691+
const parser::OmpArgument &arg{args.v.front()};
3692+
const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
3693+
3694+
if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
3695+
TODO(converter.getCurrentLocation(),
3696+
"multiple types in declare reduction is not yet supported");
3697+
3698+
mlir::Type reductionType = getReductionType(converter, specifier);
3699+
ReductionProcessor::GenCombinerCBTy genCombinerCB =
3700+
processReductionCombiner(converter, symTable, semaCtx, specifier);
3701+
const parser::OmpClauseList &initializer =
3702+
declareReductionConstruct.v.Clauses();
3703+
if (initializer.v.size() > 0) {
3704+
List<Clause> clauses = makeClauses(initializer, semaCtx);
3705+
ReductionProcessor::GenInitValueCBTy genInitValueCB;
3706+
ClauseProcessor cp(converter, semaCtx, clauses);
3707+
const parser::OmpClause::Initializer &iclause{
3708+
std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
3709+
cp.processInitializer(symTable, iclause, genInitValueCB);
3710+
const auto &identifier =
3711+
std::get<parser::OmpReductionIdentifier>(specifier.t);
3712+
const auto &designator =
3713+
std::get<parser::ProcedureDesignator>(identifier.u);
3714+
const auto &reductionName = std::get<parser::Name>(designator.u);
3715+
bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
3716+
ReductionProcessor::createDeclareReductionHelper<
3717+
mlir::omp::DeclareReductionOp>(
3718+
converter, reductionName.ToString(), reductionType,
3719+
converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
3720+
} else {
3721+
TODO(converter.getCurrentLocation(),
3722+
"declare reduction without an initializer clause is not yet "
3723+
"supported");
3724+
}
35793725
}
35803726

35813727
static void

0 commit comments

Comments
 (0)