99// Emit OpenACC Stmt nodes as CIR code.
1010//
1111// ===----------------------------------------------------------------------===//
12+ #include < type_traits>
1213
1314#include " CIRGenBuilder.h"
1415#include " CIRGenFunction.h"
@@ -23,22 +24,39 @@ using namespace cir;
2324using namespace mlir ::acc;
2425
2526namespace {
27+ // Simple type-trait to see if the first template arg is one of the list, so we
28+ // can tell whether to `if-constexpr` a bunch of stuff.
29+ template <typename ToTest, typename T, typename ... Tys>
30+ constexpr bool isOneOfTypes =
31+ std::is_same_v<ToTest, T> || isOneOfTypes<ToTest, Tys...>;
32+ template <typename ToTest, typename T>
33+ constexpr bool isOneOfTypes<ToTest, T> = std::is_same_v<ToTest, T>;
34+
2635class OpenACCClauseCIREmitter final
2736 : public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
2837 CIRGenModule &cgm;
38+ // This is necessary since a few of the clauses emit differently based on the
39+ // directive kind they are attached to.
40+ OpenACCDirectiveKind dirKind;
41+ SourceLocation dirLoc;
2942
3043 struct AttributeData {
3144 // Value of the 'default' attribute, added on 'data' and 'compute'/etc
3245 // constructs as a 'default-attr'.
3346 std::optional<ClauseDefaultValue> defaultVal = std::nullopt ;
47+ // For directives that have their device type architectures listed in
48+ // attributes (init/shutdown/etc), the list of architectures to be emitted.
49+ llvm::SmallVector<mlir::acc::DeviceType> deviceTypeArchs{};
3450 } attrData;
3551
3652 void clauseNotImplemented (const OpenACCClause &c) {
3753 cgm.errorNYI (c.getSourceRange (), " OpenACC Clause" , c.getClauseKind ());
3854 }
3955
4056public:
41- OpenACCClauseCIREmitter (CIRGenModule &cgm) : cgm(cgm) {}
57+ OpenACCClauseCIREmitter (CIRGenModule &cgm, OpenACCDirectiveKind dirKind,
58+ SourceLocation dirLoc)
59+ : cgm(cgm), dirKind(dirKind), dirLoc(dirLoc) {}
4260
4361 void VisitClause (const OpenACCClause &clause) {
4462 clauseNotImplemented (clause);
@@ -57,31 +75,92 @@ class OpenACCClauseCIREmitter final
5775 }
5876 }
5977
78+ mlir::acc::DeviceType decodeDeviceType (const IdentifierInfo *ii) {
79+ // '*' case leaves no identifier-info, just a nullptr.
80+ if (!ii)
81+ return mlir::acc::DeviceType::Star;
82+ return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName ())
83+ .CaseLower (" default" , mlir::acc::DeviceType::Default)
84+ .CaseLower (" host" , mlir::acc::DeviceType::Host)
85+ .CaseLower (" multicore" , mlir::acc::DeviceType::Multicore)
86+ .CasesLower (" nvidia" , " acc_device_nvidia" ,
87+ mlir::acc::DeviceType::Nvidia)
88+ .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
89+ }
90+
91+ void VisitDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
92+
93+ switch (dirKind) {
94+ case OpenACCDirectiveKind::Init:
95+ case OpenACCDirectiveKind::Shutdown: {
96+ // Device type has a list that is either a 'star' (emitted as 'star'),
97+ // or an identifer list, all of which get added for attributes.
98+
99+ for (const DeviceTypeArgument &arg : clause.getArchitectures ())
100+ attrData.deviceTypeArchs .push_back (decodeDeviceType (arg.first ));
101+ break ;
102+ }
103+ default :
104+ return clauseNotImplemented (clause);
105+ }
106+ }
107+
60108 // Apply any of the clauses that resulted in an 'attribute'.
61- template <typename Op> void applyAttributes (Op &op) {
62- if (attrData.defaultVal .has_value ())
63- op.setDefaultAttr (*attrData.defaultVal );
109+ template <typename Op>
110+ void applyAttributes (CIRGenBuilderTy &builder, Op &op) {
111+
112+ if (attrData.defaultVal .has_value ()) {
113+ // FIXME: OpenACC: as we implement this for other directive kinds, we have
114+ // to expand this list.
115+ // This type-trait checks if 'op'(the first arg) is one of the mlir::acc
116+ // operations listed in the rest of the arguments.
117+ if constexpr (isOneOfTypes<Op, ParallelOp, SerialOp, KernelsOp, DataOp>)
118+ op.setDefaultAttr (*attrData.defaultVal );
119+ else
120+ cgm.errorNYI (dirLoc, " OpenACC 'default' clause lowering for " , dirKind);
121+ }
122+
123+ if (!attrData.deviceTypeArchs .empty ()) {
124+ // FIXME: OpenACC: as we implement this for other directive kinds, we have
125+ // to expand this list, or more likely, have a 'noop' branch as most other
126+ // uses of this apply to the operands instead.
127+ // This type-trait checks if 'op'(the first arg) is one of the mlir::acc
128+ if constexpr (isOneOfTypes<Op, InitOp, ShutdownOp>) {
129+ llvm::SmallVector<mlir::Attribute> deviceTypes;
130+ for (mlir::acc::DeviceType DT : attrData.deviceTypeArchs )
131+ deviceTypes.push_back (
132+ mlir::acc::DeviceTypeAttr::get (builder.getContext (), DT));
133+
134+ op.setDeviceTypesAttr (
135+ mlir::ArrayAttr::get (builder.getContext (), deviceTypes));
136+ } else {
137+ cgm.errorNYI (dirLoc, " OpenACC 'device_type' clause lowering for " ,
138+ dirKind);
139+ }
140+ }
64141 }
65142};
143+
66144} // namespace
67145
68146template <typename Op, typename TermOp>
69147mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt (
70- mlir::Location start, mlir::Location end,
71- llvm::ArrayRef<const OpenACCClause *> clauses, const Stmt *associatedStmt) {
148+ mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
149+ SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
150+ const Stmt *associatedStmt) {
72151 mlir::LogicalResult res = mlir::success ();
73152
74153 llvm::SmallVector<mlir::Type> retTy;
75154 llvm::SmallVector<mlir::Value> operands;
76155
77156 // Clause-emitter must be here because it might modify operands.
78- OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule ());
157+ OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule (), dirKind, dirLoc );
79158 clauseEmitter.VisitClauseList (clauses);
80159
81160 auto op = builder.create <Op>(start, retTy, operands);
82161
83162 // Apply the attributes derived from the clauses.
84- clauseEmitter.applyAttributes (op);
163+ clauseEmitter.applyAttributes (builder, op);
85164
86165 mlir::Block &block = op.getRegion ().emplaceBlock ();
87166 mlir::OpBuilder::InsertionGuard guardCase (builder);
@@ -95,19 +174,21 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
95174}
96175
97176template <typename Op>
98- mlir::LogicalResult
99- CIRGenFunction::emitOpenACCOp ( mlir::Location start,
100- llvm::ArrayRef<const OpenACCClause *> clauses) {
177+ mlir::LogicalResult CIRGenFunction::emitOpenACCOp (
178+ mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc ,
179+ llvm::ArrayRef<const OpenACCClause *> clauses) {
101180 mlir::LogicalResult res = mlir::success ();
102181
103182 llvm::SmallVector<mlir::Type> retTy;
104183 llvm::SmallVector<mlir::Value> operands;
105184
106185 // Clause-emitter must be here because it might modify operands.
107- OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule ());
186+ OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule (), dirKind, dirLoc );
108187 clauseEmitter.VisitClauseList (clauses);
109188
110- builder.create <Op>(start, retTy, operands);
189+ auto op = builder.create <Op>(start, retTy, operands);
190+ // Apply the attributes derived from the clauses.
191+ clauseEmitter.applyAttributes (builder, op);
111192 return res;
112193}
113194
@@ -119,13 +200,16 @@ CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
119200 switch (s.getDirectiveKind ()) {
120201 case OpenACCDirectiveKind::Parallel:
121202 return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
122- start, end, s.clauses (), s.getStructuredBlock ());
203+ start, end, s.getDirectiveKind (), s.getDirectiveLoc (), s.clauses (),
204+ s.getStructuredBlock ());
123205 case OpenACCDirectiveKind::Serial:
124206 return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
125- start, end, s.clauses (), s.getStructuredBlock ());
207+ start, end, s.getDirectiveKind (), s.getDirectiveLoc (), s.clauses (),
208+ s.getStructuredBlock ());
126209 case OpenACCDirectiveKind::Kernels:
127210 return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
128- start, end, s.clauses (), s.getStructuredBlock ());
211+ start, end, s.getDirectiveKind (), s.getDirectiveLoc (), s.clauses (),
212+ s.getStructuredBlock ());
129213 default :
130214 llvm_unreachable (" invalid compute construct kind" );
131215 }
@@ -137,18 +221,22 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
137221 mlir::Location end = getLoc (s.getSourceRange ().getEnd ());
138222
139223 return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
140- start, end, s.clauses (), s.getStructuredBlock ());
224+ start, end, s.getDirectiveKind (), s.getDirectiveLoc (), s.clauses (),
225+ s.getStructuredBlock ());
141226}
142227
143228mlir::LogicalResult
144229CIRGenFunction::emitOpenACCInitConstruct (const OpenACCInitConstruct &s) {
145230 mlir::Location start = getLoc (s.getSourceRange ().getEnd ());
146- return emitOpenACCOp<InitOp>(start, s.clauses ());
231+ return emitOpenACCOp<InitOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
232+ s.clauses ());
147233}
234+
148235mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct (
149236 const OpenACCShutdownConstruct &s) {
150237 mlir::Location start = getLoc (s.getSourceRange ().getEnd ());
151- return emitOpenACCOp<ShutdownOp>(start, s.clauses ());
238+ return emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind (),
239+ s.getDirectiveLoc (), s.clauses ());
152240}
153241
154242mlir::LogicalResult
0 commit comments