@@ -46,7 +46,17 @@ class OpenACCClauseCIREmitter final
4646 // diagnostics are gone.
4747 SourceLocation dirLoc;
4848
49- const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr ;
49+ llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;
50+
51+ void setLastDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
52+ lastDeviceTypeValues.clear ();
53+
54+ llvm::for_each (clause.getArchitectures (),
55+ [this ](const DeviceTypeArgument &arg) {
56+ lastDeviceTypeValues.push_back (
57+ decodeDeviceType (arg.getIdentifierInfo ()));
58+ });
59+ }
5060
5161 void clauseNotImplemented (const OpenACCClause &c) {
5262 cgf.cgm .errorNYI (c.getSourceRange (), " OpenACC Clause" , c.getClauseKind ());
@@ -95,114 +105,6 @@ class OpenACCClauseCIREmitter final
95105 .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
96106 }
97107
98- // Overload of this function that only returns the device-types list.
99- mlir::ArrayAttr
100- handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes) {
101- mlir::ValueRange argument;
102- mlir::MutableOperandRange range{operation};
103-
104- return handleDeviceTypeAffectedClause (existingDeviceTypes, argument, range);
105- }
106- // Overload of this function for when 'segments' aren't necessary.
107- mlir::ArrayAttr
108- handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
109- mlir::ValueRange argument,
110- mlir::MutableOperandRange argCollection) {
111- llvm::SmallVector<int32_t > segments;
112- assert (argument.size () <= 1 &&
113- " Overload only for cases where segments don't need to be added" );
114- return handleDeviceTypeAffectedClause (existingDeviceTypes, argument,
115- argCollection, segments);
116- }
117-
118- // Handle a clause affected by the 'device_type' to the point that they need
119- // to have attributes added in the correct/corresponding order, such as
120- // 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
121- // a collection of operands that need to be appended to the `argCollection` as
122- // we're adding a 'device_type' entry. If there is more than 0 elements in
123- // the 'argument', the collection must be non-null, as it is needed to add to
124- // it.
125- // As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to
126- // be maintained, this takes a list of segments that will be updated with the
127- // proper counts as 'argument' elements are added.
128- //
129- // In MLIR, the 'operands' are stored as a large array, with a separate array
130- // of 'segments' that show which 'operand' applies to which 'operand-kind'.
131- // That is, a 'num_workers' operand-kind or 'num_vectors' operand-kind.
132- //
133- // So the operands array might have 4 elements, but the 'segments' array will
134- // be something like:
135- //
136- // {0, 0, 0, 2, 0, 1, 1, 0, 0...}
137- //
138- // Where each position belongs to a specific 'operand-kind'. So that
139- // specifies that whichever operand-kind corresponds with index '3' has 2
140- // elements, and should take the 1st 2 operands off the list (since all
141- // preceding values are 0). operand-kinds corresponding to 5 and 6 each have
142- // 1 element.
143- //
144- // Fortunately, the `MutableOperandRange` append function actually takes care
145- // of that for us at the 'top level'.
146- //
147- // However, in cases like `num_gangs' or 'wait', where each individual
148- // 'element' might be itself array-like, there is a separate 'segments' array
149- // for them. So in the case of:
150- //
151- // device_type(nvidia, radeon) num_gangs(1, 2, 3)
152- //
153- // We have to emit that as TWO arrays into the IR (where the device_type is an
154- // attribute), so they look like:
155- //
156- // num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],\
157- // {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>])
158- //
159- // When stored in the 'operands' list, the top-level 'segment' for
160- // 'num_gangs' just shows 6 elements. In order to get the array-like
161- // apperance, the 'numGangsSegments' list is kept as well. In the above case,
162- // we've inserted 6 operands, so the 'numGangsSegments' must contain 2
163- // elements, 1 per array, and each will have a value of 3. The verifier will
164- // ensure that the collections counts are correct.
165- mlir::ArrayAttr
166- handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
167- mlir::ValueRange argument,
168- mlir::MutableOperandRange argCollection,
169- llvm::SmallVector<int32_t > &segments) {
170- llvm::SmallVector<mlir::Attribute> deviceTypes;
171-
172- // Collect the 'existing' device-type attributes so we can re-create them
173- // and insert them.
174- if (existingDeviceTypes) {
175- for (const mlir::Attribute &Attr : existingDeviceTypes)
176- deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
177- builder.getContext (),
178- cast<mlir::acc::DeviceTypeAttr>(Attr).getValue ()));
179- }
180-
181- // Insert 1 version of the 'expr' to the NumWorkers list per-current
182- // device type.
183- if (lastDeviceTypeClause) {
184- for (const DeviceTypeArgument &arch :
185- lastDeviceTypeClause->getArchitectures ()) {
186- deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
187- builder.getContext (), decodeDeviceType (arch.getIdentifierInfo ())));
188- if (!argument.empty ()) {
189- argCollection.append (argument);
190- segments.push_back (argument.size ());
191- }
192- }
193- } else {
194- // Else, we just add a single for 'none'.
195- deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
196- builder.getContext (), mlir::acc::DeviceType::None));
197- if (!argument.empty ()) {
198- argCollection.append (argument);
199- segments.push_back (argument.size ());
200- }
201- }
202-
203- return mlir::ArrayAttr::get (builder.getContext (), deviceTypes);
204- }
205-
206108public:
207109 OpenACCClauseCIREmitter (OpTy &operation, CIRGenFunction &cgf,
208110 CIRGenBuilderTy &builder,
@@ -236,7 +138,8 @@ class OpenACCClauseCIREmitter final
236138 }
237139
238140 void VisitDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
239- lastDeviceTypeClause = &clause;
141+ setLastDeviceTypeClause (clause);
142+
240143 if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
241144 llvm::for_each (
242145 clause.getArchitectures (), [this ](const DeviceTypeArgument &arg) {
@@ -253,8 +156,8 @@ class OpenACCClauseCIREmitter final
253156 } else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp,
254157 DataOp>) {
255158 // Nothing to do here, these constructs don't have any IR for these, as
256- // they just modify the other clauses IR. So setting of `lastDeviceType`
257- // (done above) is all we need.
159+ // they just modify the other clauses IR. So setting of
160+ // `lastDeviceTypeValues` (done above) is all we need.
258161 } else {
259162 // TODO: When we've implemented this for everything, switch this to an
260163 // unreachable. update, data, loop, routine, combined constructs remain.
@@ -264,10 +167,9 @@ class OpenACCClauseCIREmitter final
264167
265168 void VisitNumWorkersClause (const OpenACCNumWorkersClause &clause) {
266169 if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
267- mlir::MutableOperandRange range = operation.getNumWorkersMutable ();
268- operation.setNumWorkersDeviceTypeAttr (handleDeviceTypeAffectedClause (
269- operation.getNumWorkersDeviceTypeAttr (),
270- createIntExpr (clause.getIntExpr ()), range));
170+ operation.addNumWorkersOperand (builder.getContext (),
171+ createIntExpr (clause.getIntExpr ()),
172+ lastDeviceTypeValues);
271173 } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
272174 llvm_unreachable (" num_workers not valid on serial" );
273175 } else {
@@ -279,10 +181,9 @@ class OpenACCClauseCIREmitter final
279181
280182 void VisitVectorLengthClause (const OpenACCVectorLengthClause &clause) {
281183 if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
282- mlir::MutableOperandRange range = operation.getVectorLengthMutable ();
283- operation.setVectorLengthDeviceTypeAttr (handleDeviceTypeAffectedClause (
284- operation.getVectorLengthDeviceTypeAttr (),
285- createIntExpr (clause.getIntExpr ()), range));
184+ operation.addVectorLengthOperand (builder.getContext (),
185+ createIntExpr (clause.getIntExpr ()),
186+ lastDeviceTypeValues);
286187 } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
287188 llvm_unreachable (" vector_length not valid on serial" );
288189 } else {
@@ -294,15 +195,12 @@ class OpenACCClauseCIREmitter final
294195
295196 void VisitAsyncClause (const OpenACCAsyncClause &clause) {
296197 if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
297- if (!clause.hasIntExpr ()) {
298- operation.setAsyncOnlyAttr (
299- handleDeviceTypeAffectedClause (operation.getAsyncOnlyAttr ()));
300- } else {
301- mlir::MutableOperandRange range = operation.getAsyncOperandsMutable ();
302- operation.setAsyncOperandsDeviceTypeAttr (handleDeviceTypeAffectedClause (
303- operation.getAsyncOperandsDeviceTypeAttr (),
304- createIntExpr (clause.getIntExpr ()), range));
305- }
198+ if (!clause.hasIntExpr ())
199+ operation.addAsyncOnly (builder.getContext (), lastDeviceTypeValues);
200+ else
201+ operation.addAsyncOperand (builder.getContext (),
202+ createIntExpr (clause.getIntExpr ()),
203+ lastDeviceTypeValues);
306204 } else if constexpr (isOneOfTypes<OpTy, WaitOp>) {
307205 // Wait doesn't have a device_type, so its handling here is slightly
308206 // different.
@@ -366,19 +264,11 @@ class OpenACCClauseCIREmitter final
366264 void VisitNumGangsClause (const OpenACCNumGangsClause &clause) {
367265 if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
368266 llvm::SmallVector<mlir::Value> values;
369-
370267 for (const Expr *E : clause.getIntExprs ())
371268 values.push_back (createIntExpr (E));
372269
373- llvm::SmallVector<int32_t > segments;
374- if (operation.getNumGangsSegments ())
375- llvm::copy (*operation.getNumGangsSegments (),
376- std::back_inserter (segments));
377-
378- mlir::MutableOperandRange range = operation.getNumGangsMutable ();
379- operation.setNumGangsDeviceTypeAttr (handleDeviceTypeAffectedClause (
380- operation.getNumGangsDeviceTypeAttr (), values, range, segments));
381- operation.setNumGangsSegments (llvm::ArrayRef<int32_t >{segments});
270+ operation.addNumGangsOperands (builder.getContext (), values,
271+ lastDeviceTypeValues);
382272 } else {
383273 // TODO: When we've implemented this for everything, switch this to an
384274 // unreachable. Combined constructs remain.
@@ -389,42 +279,15 @@ class OpenACCClauseCIREmitter final
389279 void VisitWaitClause (const OpenACCWaitClause &clause) {
390280 if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
391281 if (!clause.hasExprs ()) {
392- operation.setWaitOnlyAttr (
393- handleDeviceTypeAffectedClause (operation.getWaitOnlyAttr ()));
282+ operation.addWaitOnly (builder.getContext (), lastDeviceTypeValues);
394283 } else {
395284 llvm::SmallVector<mlir::Value> values;
396-
397285 if (clause.hasDevNumExpr ())
398286 values.push_back (createIntExpr (clause.getDevNumExpr ()));
399287 for (const Expr *E : clause.getQueueIdExprs ())
400288 values.push_back (createIntExpr (E));
401-
402- llvm::SmallVector<int32_t > segments;
403- if (operation.getWaitOperandsSegments ())
404- llvm::copy (*operation.getWaitOperandsSegments (),
405- std::back_inserter (segments));
406-
407- unsigned beforeSegmentSize = segments.size ();
408-
409- mlir::MutableOperandRange range = operation.getWaitOperandsMutable ();
410- operation.setWaitOperandsDeviceTypeAttr (handleDeviceTypeAffectedClause (
411- operation.getWaitOperandsDeviceTypeAttr (), values, range,
412- segments));
413- operation.setWaitOperandsSegments (segments);
414-
415- // In addition to having to set the 'segments', wait also has a list of
416- // bool attributes whether it is annotated with 'devnum'. We can use
417- // our knowledge of how much the 'segments' array grew to determine how
418- // many we need to add.
419- llvm::SmallVector<bool > hasDevNums;
420- if (operation.getHasWaitDevnumAttr ())
421- for (mlir::Attribute A : operation.getHasWaitDevnumAttr ())
422- hasDevNums.push_back (cast<mlir::BoolAttr>(A).getValue ());
423-
424- hasDevNums.insert (hasDevNums.end (), segments.size () - beforeSegmentSize,
425- clause.hasDevNumExpr ());
426-
427- operation.setHasWaitDevnumAttr (builder.getBoolArrayAttr (hasDevNums));
289+ operation.addWaitOperands (builder.getContext (), clause.hasDevNumExpr (),
290+ values, lastDeviceTypeValues);
428291 }
429292 } else {
430293 // TODO: When we've implemented this for everything, switch this to an
@@ -589,7 +452,7 @@ CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
589452 if (s.hasDevNumExpr ())
590453 waitOp.getWaitDevnumMutable ().append (createIntExpr (s.getDevNumExpr ()));
591454
592- for (Expr *QueueExpr : s.getQueueIdExprs ())
455+ for (Expr *QueueExpr : s.getQueueIdExprs ())
593456 waitOp.getWaitOperandsMutable ().append (createIntExpr (QueueExpr));
594457 }
595458
0 commit comments