@@ -82,6 +82,56 @@ class OpenACCClauseCIREmitter final
8282 return conversionOp.getResult (0 );
8383 }
8484
85+ mlir::acc::DeviceType decodeDeviceType (const IdentifierInfo *ii) { // Translate one architecture identifier into an mlir::acc::DeviceType.
86+ // The '*' case leaves no identifier-info, just a nullptr, and maps to Star.
87+ if (!ii)
88+ return mlir::acc::DeviceType::Star;
89+ return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName ())
90+ .CaseLower (" default" , mlir::acc::DeviceType::Default)
91+ .CaseLower (" host" , mlir::acc::DeviceType::Host)
92+ .CaseLower (" multicore" , mlir::acc::DeviceType::Multicore)
93+ .CasesLower (" nvidia" , " acc_device_nvidia" ,
94+ mlir::acc::DeviceType::Nvidia)
95+ .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon); // NOTE(review): no .Default(); assumes only these spellings reach here - confirm.
96+ }
97+
98+ // Handle a clause that is affected by 'device-type' to the point that its
99+ // attributes must be added in the corresponding order, such as 'num_workers'
100+ // or 'vector_length' on a compute construct. Returns the rebuilt attribute list.
101+ mlir::ArrayAttr
102+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
103+ mlir::Value argument,
104+ mlir::MutableOperandRange &argCollection) {
105+ llvm::SmallVector<mlir::Attribute> deviceTypes;
106+
107+ // Collect the 'existing' device-type attributes so we can re-create them
108+ // first; the entries for this clause are appended after them.
109+ if (existingDeviceTypes) {
110+ for (const mlir::Attribute &Attr : existingDeviceTypes)
111+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
112+ builder.getContext (),
113+ cast<mlir::acc::DeviceTypeAttr>(Attr).getValue ()));
114+ }
115+
116+ // Append one copy of 'argument' to 'argCollection' per current device
117+ // type, so each new attribute is added together with one operand.
118+ if (lastDeviceTypeClause) {
119+ for (const DeviceTypeArgument &arch :
120+ lastDeviceTypeClause->getArchitectures ()) {
121+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
122+ builder.getContext (), decodeDeviceType (arch.getIdentifierInfo ())));
123+ argCollection.append (argument);
124+ }
125+ } else {
126+ // Else, we just add a single entry for 'none'.
127+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
128+ builder.getContext (), mlir::acc::DeviceType::None));
129+ argCollection.append (argument);
130+ }
131+
132+ return mlir::ArrayAttr::get (builder.getContext (), deviceTypes);
133+ }
134+
85135public:
86136 OpenACCClauseCIREmitter (OpTy &operation, CIRGenFunction &cgf,
87137 CIRGenBuilderTy &builder,
@@ -112,19 +162,6 @@ class OpenACCClauseCIREmitter final
112162 }
113163 }
114164
115- mlir::acc::DeviceType decodeDeviceType (const IdentifierInfo *ii) {
116- // '*' case leaves no identifier-info, just a nullptr.
117- if (!ii)
118- return mlir::acc::DeviceType::Star;
119- return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName ())
120- .CaseLower (" default" , mlir::acc::DeviceType::Default)
121- .CaseLower (" host" , mlir::acc::DeviceType::Host)
122- .CaseLower (" multicore" , mlir::acc::DeviceType::Multicore)
123- .CasesLower (" nvidia" , " acc_device_nvidia" ,
124- mlir::acc::DeviceType::Nvidia)
125- .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
126- }
127-
128165 void VisitDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
129166 lastDeviceTypeClause = &clause;
130167 if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
@@ -165,45 +202,30 @@ class OpenACCClauseCIREmitter final
165202
166203 void VisitNumWorkersClause (const OpenACCNumWorkersClause &clause) {
167204 if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
168- // Collect the 'existing' device-type attributes so we can re-create them
169- // and insert them.
170- llvm::SmallVector<mlir::Attribute> deviceTypes;
171- mlir::ArrayAttr existingDeviceTypes =
172- operation.getNumWorkersDeviceTypeAttr ();
173-
174- if (existingDeviceTypes) {
175- for (mlir::Attribute Attr : existingDeviceTypes)
176- deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
177- builder.getContext (),
178- cast<mlir::acc::DeviceTypeAttr>(Attr).getValue ()));
179- }
180-
181- // Insert 1 version of the 'int-expr' to the NumWorkers list per-current
182- // device type.
183- mlir::Value intExpr = createIntExpr (clause.getIntExpr ());
184- if (lastDeviceTypeClause) {
185- for (const DeviceTypeArgument &arg :
186- lastDeviceTypeClause->getArchitectures ()) {
187- deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
188- builder.getContext (), decodeDeviceType (arg.getIdentifierInfo ())));
189- operation.getNumWorkersMutable ().append (intExpr);
190- }
191- } else {
192- // Else, we just add a single for 'none'.
193- deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
194- builder.getContext (), mlir::acc::DeviceType::None));
195- operation.getNumWorkersMutable ().append (intExpr);
196- }
197-
198- operation.setNumWorkersDeviceTypeAttr (
199- mlir::ArrayAttr::get (builder.getContext (), deviceTypes));
205+ mlir::MutableOperandRange range = operation.getNumWorkersMutable (); // Named local: the helper takes the range by non-const reference.
206+ operation.setNumWorkersDeviceTypeAttr (handleDeviceTypeAffectedClause (
207+ operation.getNumWorkersDeviceTypeAttr (),
208+ createIntExpr (clause.getIntExpr ()), range));
200209 } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
201210 llvm_unreachable (" num_workers not valid on serial" );
202211 } else {
203212 return clauseNotImplemented (clause);
204213 }
205214 }
206215
216+ void VisitVectorLengthClause (const OpenACCVectorLengthClause &clause) {
217+ if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
218+ mlir::MutableOperandRange range = operation.getVectorLengthMutable (); // Named local: the helper takes the range by non-const reference.
219+ operation.setVectorLengthDeviceTypeAttr (handleDeviceTypeAffectedClause (
220+ operation.getVectorLengthDeviceTypeAttr (),
221+ createIntExpr (clause.getIntExpr ()), range));
222+ } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
223+ llvm_unreachable (" vector_length not valid on serial" );
224+ } else {
225+ return clauseNotImplemented (clause);
226+ }
227+ }
228+
207229 void VisitSelfClause (const OpenACCSelfClause &clause) {
208230 if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp>) {
209231 if (clause.isEmptySelfClause ()) {
0 commit comments