@@ -95,19 +95,78 @@ class OpenACCClauseCIREmitter final
9595 .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
9696 }
9797
98- // Handle a clause affected by the 'device-type' to the point that they need
99- // to have the attributes added in the correct/corresponding order, such as
100- // 'num_workers' or 'vector_length' on a compute construct. For cases where we
101- // don't have an expression 'argument' that needs to be added to an operand
102- // and only care about the 'device-type' list, we can use this with 'argument'
103- // as 'std::nullopt'. If 'argument' is NOT 'std::nullopt' (that is, has a
104- // value), argCollection must also be non-null. For cases where we don't have
105- // an argument that needs to be added to an additional one (such as asyncOnly)
106- // we can use this with 'argument' as std::nullopt.
107- mlir::ArrayAttr handleDeviceTypeAffectedClause (
108- mlir::ArrayAttr existingDeviceTypes,
109- std::optional<mlir::Value> argument = std::nullopt ,
110- mlir::MutableOperandRange *argCollection = nullptr ) {
98+ // Overload of this function that only returns the device-types list.
99+ mlir::ArrayAttr
100+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes) {
101+ mlir::ValueRange argument;
102+ mlir::MutableOperandRange range{operation};
103+
104+ return handleDeviceTypeAffectedClause (existingDeviceTypes, argument, range);
105+ }
106+ // Overload of this function for when 'segments' aren't necessary.
107+ mlir::ArrayAttr
108+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
109+ mlir::ValueRange argument,
110+ mlir::MutableOperandRange argCollection) {
111+ llvm::SmallVector<int32_t > segments;
112+ assert (argument.size () <= 1 &&
113+ " Overload only for cases where segments don't need to be added" );
114+ return handleDeviceTypeAffectedClause (existingDeviceTypes, argument,
115+ argCollection, segments);
116+ }
117+
118+ // Handle a clause affected by the 'device_type' to the point that they need
119+ // to have attributes added in the correct/corresponding order, such as
120+ // 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
121+ // a collection of operands that need to be appended to the `argCollection` as
122+ // we're adding a 'device_type' entry. If there is more than 0 elements in
123+ // the 'argument', the collection must be non-null, as it is needed to add to
124+ // it.
125+ // As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to
126+ // be maintained, this takes a list of segments that will be updated with the
127+ // proper counts as 'argument' elements are added.
128+ //
129+ // In MLIR, the 'operands' are stored as a large array, with a separate array
130+ // of 'segments' that show which 'operand' applies to which 'operand-kind'.
131+ // That is, a 'num_workers' operand-kind or 'num_vectors' operand-kind.
132+ //
133+ // So the operands array might have 4 elements, but the 'segments' array will
134+ // be something like:
135+ //
136+ // {0, 0, 0, 2, 0, 1, 1, 0, 0...}
137+ //
138+ // Where each position belongs to a specific 'operand-kind'. So that
139+ // specifies that whichever operand-kind corresponds with index '3' has 2
140+ // elements, and should take the 1st 2 operands off the list (since all
141+ // preceding values are 0). operand-kinds corresponding to 5 and 6 each have
142+ // 1 element.
143+ //
144+ // Fortunately, the `MutableOperandRange` append function actually takes care
145+ // of that for us at the 'top level'.
146+ //
147+ // However, in cases like `num_gangs' or 'wait', where each individual
148+ // 'element' might be itself array-like, there is a separate 'segments' array
149+ // for them. So in the case of:
150+ //
151+ // device_type(nvidia, radeon) num_gangs(1, 2, 3)
152+ //
153+ // We have to emit that as TWO arrays into the IR (where the device_type is an
154+ // attribute), so they look like:
155+ //
156+ // num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],\
157+ // {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>])
158+ //
159+ // When stored in the 'operands' list, the top-level 'segment' for
160+ // 'num_gangs' just shows 6 elements. In order to get the array-like
161+ // apperance, the 'numGangsSegments' list is kept as well. In the above case,
162+ // we've inserted 6 operands, so the 'numGangsSegments' must contain 2
163+ // elements, 1 per array, and each will have a value of 3. The verifier will
164+ // ensure that the collections counts are correct.
165+ mlir::ArrayAttr
166+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
167+ mlir::ValueRange argument,
168+ mlir::MutableOperandRange argCollection,
169+ llvm::SmallVector<int32_t > &segments) {
111170 llvm::SmallVector<mlir::Attribute> deviceTypes;
112171
113172 // Collect the 'existing' device-type attributes so we can re-create them
@@ -126,18 +185,18 @@ class OpenACCClauseCIREmitter final
126185 lastDeviceTypeClause->getArchitectures ()) {
127186 deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
128187 builder.getContext (), decodeDeviceType (arch.getIdentifierInfo ())));
129- if (argument) {
130- assert ( argCollection);
131- argCollection-> append (* argument);
188+ if (! argument. empty () ) {
189+ argCollection. append (argument );
190+ segments. push_back ( argument. size () );
132191 }
133192 }
134193 } else {
135194 // Else, we just add a single for 'none'.
136195 deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
137196 builder.getContext (), mlir::acc::DeviceType::None));
138- if (argument) {
139- assert ( argCollection);
140- argCollection-> append (* argument);
197+ if (! argument. empty () ) {
198+ argCollection. append (argument );
199+ segments. push_back ( argument. size () );
141200 }
142201 }
143202
@@ -170,7 +229,8 @@ class OpenACCClauseCIREmitter final
170229 break ;
171230 }
172231 } else {
173- // Combined Constructs left.
232+ // TODO: When we've implemented this for everything, switch this to an
233+ // unreachable. Combined constructs remain.
174234 return clauseNotImplemented (clause);
175235 }
176236 }
@@ -210,7 +270,8 @@ class OpenACCClauseCIREmitter final
210270 // they just modify the other clauses IR. So setting of `lastDeviceType`
211271 // (done above) is all we need.
212272 } else {
213- // update, data, loop, routine, combined remain.
273+ // TODO: When we've implemented this for everything, switch this to an
274+ // unreachable. update, data, loop, routine, combined constructs remain.
214275 return clauseNotImplemented (clause);
215276 }
216277 }
@@ -220,11 +281,12 @@ class OpenACCClauseCIREmitter final
220281 mlir::MutableOperandRange range = operation.getNumWorkersMutable ();
221282 operation.setNumWorkersDeviceTypeAttr (handleDeviceTypeAffectedClause (
222283 operation.getNumWorkersDeviceTypeAttr (),
223- createIntExpr (clause.getIntExpr ()), & range));
284+ createIntExpr (clause.getIntExpr ()), range));
224285 } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
225286 llvm_unreachable (" num_workers not valid on serial" );
226287 } else {
227- // Combined Remain.
288+ // TODO: When we've implemented this for everything, switch this to an
289+ // unreachable. Combined constructs remain.
228290 return clauseNotImplemented (clause);
229291 }
230292 }
@@ -234,11 +296,12 @@ class OpenACCClauseCIREmitter final
234296 mlir::MutableOperandRange range = operation.getVectorLengthMutable ();
235297 operation.setVectorLengthDeviceTypeAttr (handleDeviceTypeAffectedClause (
236298 operation.getVectorLengthDeviceTypeAttr (),
237- createIntExpr (clause.getIntExpr ()), & range));
299+ createIntExpr (clause.getIntExpr ()), range));
238300 } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
239301 llvm_unreachable (" vector_length not valid on serial" );
240302 } else {
241- // Combined remain.
303+ // TODO: When we've implemented this for everything, switch this to an
304+ // unreachable. Combined constructs remain.
242305 return clauseNotImplemented (clause);
243306 }
244307 }
@@ -252,10 +315,12 @@ class OpenACCClauseCIREmitter final
252315 mlir::MutableOperandRange range = operation.getAsyncOperandsMutable ();
253316 operation.setAsyncOperandsDeviceTypeAttr (handleDeviceTypeAffectedClause (
254317 operation.getAsyncOperandsDeviceTypeAttr (),
255- createIntExpr (clause.getIntExpr ()), & range));
318+ createIntExpr (clause.getIntExpr ()), range));
256319 }
257320 } else {
258- // Data, enter data, exit data, update, wait, combined remain.
321+ // TODO: When we've implemented this for everything, switch this to an
322+ // unreachable. Combined constructs remain. Data, enter data, exit data,
323+ // update, wait, combined constructs remain.
259324 return clauseNotImplemented (clause);
260325 }
261326 }
@@ -272,7 +337,8 @@ class OpenACCClauseCIREmitter final
272337 llvm_unreachable (" var-list version of self shouldn't get here" );
273338 }
274339 } else {
275- // update and combined remain.
340+ // TODO: When we've implemented this for everything, switch this to an
341+ // unreachable. If, combined constructs remain.
276342 return clauseNotImplemented (clause);
277343 }
278344 }
@@ -286,7 +352,9 @@ class OpenACCClauseCIREmitter final
286352 // 'if' applies to most of the constructs, but hold off on lowering them
287353 // until we can write tests/know what we're doing with codegen to make
288354 // sure we get it right.
289- // Enter data, exit data, host_data, update, wait, combined remain.
355+ // TODO: When we've implemented this for everything, switch this to an
356+ // unreachable. Enter data, exit data, host_data, update, wait, combined
357+ // constructs remain.
290358 return clauseNotImplemented (clause);
291359 }
292360 }
@@ -301,6 +369,29 @@ class OpenACCClauseCIREmitter final
301369 }
302370 }
303371
372+ void VisitNumGangsClause (const OpenACCNumGangsClause &clause) {
373+ if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
374+ llvm::SmallVector<mlir::Value> values;
375+
376+ for (const Expr *E : clause.getIntExprs ())
377+ values.push_back (createIntExpr (E));
378+
379+ llvm::SmallVector<int32_t > segments;
380+ if (operation.getNumGangsSegments ())
381+ llvm::copy (*operation.getNumGangsSegments (),
382+ std::back_inserter (segments));
383+
384+ mlir::MutableOperandRange range = operation.getNumGangsMutable ();
385+ operation.setNumGangsDeviceTypeAttr (handleDeviceTypeAffectedClause (
386+ operation.getNumGangsDeviceTypeAttr (), values, range, segments));
387+ operation.setNumGangsSegments (llvm::ArrayRef<int32_t >{segments});
388+ } else {
389+ // TODO: When we've implemented this for everything, switch this to an
390+ // unreachable. Combined constructs remain.
391+ return clauseNotImplemented (clause);
392+ }
393+ }
394+
304395 void VisitDefaultAsyncClause (const OpenACCDefaultAsyncClause &clause) {
305396 if constexpr (isOneOfTypes<OpTy, SetOp>) {
306397 operation.getDefaultAsyncMutable ().append (
0 commit comments