99// Emit OpenACC Stmt nodes as CIR code.
1010//
1111// ===----------------------------------------------------------------------===//
12+ #include < type_traits>
1213
1314#include " CIRGenBuilder.h"
1415#include " CIRGenFunction.h"
@@ -23,22 +24,39 @@ using namespace cir;
2324using namespace mlir ::acc;
2425
2526namespace {
27+ // Simple type-trait to see if the first template arg is one of the list, so we
28+ // can tell whether to `if-constexpr` a bunch of stuff.
29+ template <typename ToTest, typename T, typename ... Tys>
30+ constexpr bool isOneOfTypes =
31+ std::is_same_v<ToTest, T> || isOneOfTypes<ToTest, Tys...>;
32+ template <typename ToTest, typename T>
33+ constexpr bool isOneOfTypes<ToTest, T> = std::is_same_v<ToTest, T>;
34+
2635class OpenACCClauseCIREmitter final
2736 : public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
2837 CIRGenModule &cgm;
38+ // This is necessary since a few of the clauses emit differently based on the
39+ // directive kind they are attached to.
40+ OpenACCDirectiveKind dirKind;
41+ SourceLocation dirLoc;
2942
3043 struct AttributeData {
3144 // Value of the 'default' attribute, added on 'data' and 'compute'/etc
3245 // constructs as a 'default-attr'.
3346 std::optional<ClauseDefaultValue> defaultVal = std::nullopt ;
47+ // For directives that have their device type architectures listed in
48+ // attributes (init/shutdown/etc), the list of architectures to be emitted.
49+ llvm::SmallVector<mlir::acc::DeviceType> deviceTypeArchs{};
3450 } attrData;
3551
3652 void clauseNotImplemented (const OpenACCClause &c) {
3753 cgm.errorNYI (c.getSourceRange (), " OpenACC Clause" , c.getClauseKind ());
3854 }
3955
4056public:
41- OpenACCClauseCIREmitter (CIRGenModule &cgm) : cgm(cgm) {}
57+ OpenACCClauseCIREmitter (CIRGenModule &cgm, OpenACCDirectiveKind dirKind,
58+ SourceLocation dirLoc)
59+ : cgm(cgm), dirKind(dirKind), dirLoc(dirLoc) {}
4260
4361 void VisitClause (const OpenACCClause &clause) {
4462 clauseNotImplemented (clause);
@@ -57,31 +75,90 @@ class OpenACCClauseCIREmitter final
5775 }
5876 }
5977
78+ mlir::acc::DeviceType decodeDeviceType (const IdentifierInfo *II) {
79+
80+ // '*' case leaves no identifier-info, just a nullptr.
81+ if (!II)
82+ return mlir::acc::DeviceType::Star;
83+ return llvm::StringSwitch<mlir::acc::DeviceType>(II->getName ())
84+ .CaseLower (" default" , mlir::acc::DeviceType::Default)
85+ .CaseLower (" host" , mlir::acc::DeviceType::Host)
86+ .CaseLower (" multicore" , mlir::acc::DeviceType::Multicore)
87+ .CasesLower (" nvidia" , " acc_device_nvidia" ,
88+ mlir::acc::DeviceType::Nvidia)
89+ .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
90+ }
91+
92+ void VisitDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
93+
94+ switch (dirKind) {
95+ case OpenACCDirectiveKind::Init:
96+ case OpenACCDirectiveKind::Shutdown: {
97+ // Device type has a list that is either a 'star' (emitted as 'star'),
98+ // or an identifer list, all of which get added for attributes.
99+
100+ for (const DeviceTypeArgument &Arg : clause.getArchitectures ())
101+ attrData.deviceTypeArchs .push_back (decodeDeviceType (Arg.first ));
102+ break ;
103+ }
104+ default :
105+ return clauseNotImplemented (clause);
106+ }
107+ }
108+
60109 // Apply any of the clauses that resulted in an 'attribute'.
61- template <typename Op> void applyAttributes (Op &op) {
62- if (attrData.defaultVal .has_value ())
63- op.setDefaultAttr (*attrData.defaultVal );
110+ template <typename Op>
111+ void applyAttributes (CIRGenBuilderTy &builder, Op &op) {
112+
113+ if (attrData.defaultVal .has_value ()) {
114+ // FIXME: OpenACC: as we implement this for other directive kinds, we have
115+ // to expand this list.
116+ if constexpr (isOneOfTypes<Op, ParallelOp, SerialOp, KernelsOp, DataOp>)
117+ op.setDefaultAttr (*attrData.defaultVal );
118+ else
119+ cgm.errorNYI (dirLoc, " OpenACC 'default' clause lowering for " , dirKind);
120+ }
121+
122+ if (!attrData.deviceTypeArchs .empty ()) {
123+ // FIXME: OpenACC: as we implement this for other directive kinds, we have
124+ // to expand this list, or more likely, have a 'noop' branch as most other
125+ // uses of this apply to the operands instead.
126+ if constexpr (isOneOfTypes<Op, InitOp, ShutdownOp>) {
127+ llvm::SmallVector<mlir::Attribute> deviceTypes;
128+ for (mlir::acc::DeviceType DT : attrData.deviceTypeArchs )
129+ deviceTypes.push_back (
130+ mlir::acc::DeviceTypeAttr::get (builder.getContext (), DT));
131+
132+ op.setDeviceTypesAttr (
133+ mlir::ArrayAttr::get (builder.getContext (), deviceTypes));
134+ } else {
135+ cgm.errorNYI (dirLoc, " OpenACC 'device_type' clause lowering for " ,
136+ dirKind);
137+ }
138+ }
64139 }
65140};
141+
66142} // namespace
67143
68144template <typename Op, typename TermOp>
69145mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt (
70- mlir::Location start, mlir::Location end,
71- llvm::ArrayRef<const OpenACCClause *> clauses, const Stmt *associatedStmt) {
146+ OpenACCDirectiveKind dirKind, SourceLocation dirLoc, mlir::Location start,
147+ mlir::Location end, llvm::ArrayRef<const OpenACCClause *> clauses,
148+ const Stmt *associatedStmt) {
72149 mlir::LogicalResult res = mlir::success ();
73150
74151 llvm::SmallVector<mlir::Type> retTy;
75152 llvm::SmallVector<mlir::Value> operands;
76153
77154 // Clause-emitter must be here because it might modify operands.
78- OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule ());
155+ OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule (), dirKind, dirLoc );
79156 clauseEmitter.VisitClauseList (clauses);
80157
81158 auto op = builder.create <Op>(start, retTy, operands);
82159
83160 // Apply the attributes derived from the clauses.
84- clauseEmitter.applyAttributes (op);
161+ clauseEmitter.applyAttributes (builder, op);
85162
86163 mlir::Block &block = op.getRegion ().emplaceBlock ();
87164 mlir::OpBuilder::InsertionGuard guardCase (builder);
@@ -96,18 +173,21 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
96173
97174template <typename Op>
98175mlir::LogicalResult
99- CIRGenFunction::emitOpenACCOp (mlir::Location start,
176+ CIRGenFunction::emitOpenACCOp (OpenACCDirectiveKind dirKind,
177+ SourceLocation dirLoc, mlir::Location start,
100178 llvm::ArrayRef<const OpenACCClause *> clauses) {
101179 mlir::LogicalResult res = mlir::success ();
102180
103181 llvm::SmallVector<mlir::Type> retTy;
104182 llvm::SmallVector<mlir::Value> operands;
105183
106184 // Clause-emitter must be here because it might modify operands.
107- OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule ());
185+ OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule (), dirKind, dirLoc );
108186 clauseEmitter.VisitClauseList (clauses);
109187
110- builder.create <Op>(start, retTy, operands);
188+ auto op = builder.create <Op>(start, retTy, operands);
189+ // Apply the attributes derived from the clauses.
190+ clauseEmitter.applyAttributes (builder, op);
111191 return res;
112192}
113193
@@ -119,13 +199,16 @@ CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
119199 switch (s.getDirectiveKind ()) {
120200 case OpenACCDirectiveKind::Parallel:
121201 return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
122- start, end, s.clauses (), s.getStructuredBlock ());
202+ s.getDirectiveKind (), s.getDirectiveLoc (), start, end, s.clauses (),
203+ s.getStructuredBlock ());
123204 case OpenACCDirectiveKind::Serial:
124205 return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
125- start, end, s.clauses (), s.getStructuredBlock ());
206+ s.getDirectiveKind (), s.getDirectiveLoc (), start, end, s.clauses (),
207+ s.getStructuredBlock ());
126208 case OpenACCDirectiveKind::Kernels:
127209 return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
128- start, end, s.clauses (), s.getStructuredBlock ());
210+ s.getDirectiveKind (), s.getDirectiveLoc (), start, end, s.clauses (),
211+ s.getStructuredBlock ());
129212 default :
130213 llvm_unreachable (" invalid compute construct kind" );
131214 }
@@ -137,18 +220,21 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
137220 mlir::Location end = getLoc (s.getSourceRange ().getEnd ());
138221
139222 return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
140- start, end, s.clauses (), s.getStructuredBlock ());
223+ s.getDirectiveKind (), s.getDirectiveLoc (), start, end, s.clauses (),
224+ s.getStructuredBlock ());
141225}
142226
143227mlir::LogicalResult
144228CIRGenFunction::emitOpenACCInitConstruct (const OpenACCInitConstruct &s) {
145229 mlir::Location start = getLoc (s.getSourceRange ().getEnd ());
146- return emitOpenACCOp<InitOp>(start, s.clauses ());
230+ return emitOpenACCOp<InitOp>(s.getDirectiveKind (), s.getDirectiveLoc (), start,
231+ s.clauses ());
147232}
148233mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct (
149234 const OpenACCShutdownConstruct &s) {
150235 mlir::Location start = getLoc (s.getSourceRange ().getEnd ());
151- return emitOpenACCOp<ShutdownOp>(start, s.clauses ());
236+ return emitOpenACCOp<ShutdownOp>(s.getDirectiveKind (), s.getDirectiveLoc (),
237+ start, s.clauses ());
152238}
153239
154240mlir::LogicalResult
0 commit comments