@@ -95,19 +95,41 @@ class OpenACCClauseCIREmitter final
9595 .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
9696 }
9797
98- // Handle a clause affected by the 'device-type' to the point that they need
99- // to have the attributes added in the correct/corresponding order, such as
100- // 'num_workers' or 'vector_length' on a compute construct. For cases where we
101- // don't have an expression 'argument' that needs to be added to an operand
102- // and only care about the 'device-type' list, we can use this with 'argument'
103- // as 'std::nullopt'. If 'argument' is NOT 'std::nullopt' (that is, has a
104- // value), argCollection must also be non-null. For cases where we don't have
105- // an argument that needs to be added to an additional one (such as asyncOnly)
106- // we can use this with 'argument' as std::nullopt.
107- mlir::ArrayAttr handleDeviceTypeAffectedClause (
108- mlir::ArrayAttr existingDeviceTypes,
109- std::optional<mlir::Value> argument = std::nullopt ,
110- mlir::MutableOperandRange *argCollection = nullptr ) {
98+ // Overload of this function that only returns the device-types list.
99+ mlir::ArrayAttr
100+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes) {
101+ mlir::ValueRange argument;
102+ mlir::MutableOperandRange range{operation};
103+
104+ return handleDeviceTypeAffectedClause (existingDeviceTypes, argument, range);
105+ }
106+ // Overload of this function for when 'segments' aren't necessary.
107+ mlir::ArrayAttr
108+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
109+ mlir::ValueRange argument,
110+ mlir::MutableOperandRange argCollection) {
111+ llvm::SmallVector<int32_t > segments;
112+ assert (argument.size () <= 1 &&
113+ " Overload only for cases where segments don't need to be added" );
114+ return handleDeviceTypeAffectedClause (existingDeviceTypes, argument,
115+ argCollection, segments);
116+ }
117+
118+ // Handle a clause affected by the 'device_type' to the point that they need
119+ // to have attributes added in the correct/corresponding order, such as
120+ // 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
121+ // a collection of operands that need to be appended to the `argCollection` as
122+ // we're adding a 'device_type' entry. If there is more than 0 elements in
123+ // the 'argument', the collection must be non-null, as it is needed to add to
124+ // it.
125+ // As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to
126+ // be maintained, this takes a list of segments that will be updated with the
127+ // proper counts as 'argument' elements are added.
128+ mlir::ArrayAttr
129+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
130+ mlir::ValueRange argument,
131+ mlir::MutableOperandRange argCollection,
132+ llvm::SmallVector<int32_t > &segments) {
111133 llvm::SmallVector<mlir::Attribute> deviceTypes;
112134
113135 // Collect the 'existing' device-type attributes so we can re-create them
@@ -126,18 +148,18 @@ class OpenACCClauseCIREmitter final
126148 lastDeviceTypeClause->getArchitectures ()) {
127149 deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
128150 builder.getContext (), decodeDeviceType (arch.getIdentifierInfo ())));
129- if (argument) {
130- assert ( argCollection);
131- argCollection-> append (* argument);
151+ if (! argument. empty () ) {
152+ argCollection. append (argument );
153+ segments. push_back ( argument. size () );
132154 }
133155 }
134156 } else {
135157 // Else, we just add a single for 'none'.
136158 deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
137159 builder.getContext (), mlir::acc::DeviceType::None));
138- if (argument) {
139- assert ( argCollection);
140- argCollection-> append (* argument);
160+ if (! argument. empty () ) {
161+ argCollection. append (argument );
162+ segments. push_back ( argument. size () );
141163 }
142164 }
143165
@@ -220,7 +242,7 @@ class OpenACCClauseCIREmitter final
220242 mlir::MutableOperandRange range = operation.getNumWorkersMutable ();
221243 operation.setNumWorkersDeviceTypeAttr (handleDeviceTypeAffectedClause (
222244 operation.getNumWorkersDeviceTypeAttr (),
223- createIntExpr (clause.getIntExpr ()), & range));
245+ createIntExpr (clause.getIntExpr ()), range));
224246 } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
225247 llvm_unreachable (" num_workers not valid on serial" );
226248 } else {
@@ -234,7 +256,7 @@ class OpenACCClauseCIREmitter final
234256 mlir::MutableOperandRange range = operation.getVectorLengthMutable ();
235257 operation.setVectorLengthDeviceTypeAttr (handleDeviceTypeAffectedClause (
236258 operation.getVectorLengthDeviceTypeAttr (),
237- createIntExpr (clause.getIntExpr ()), & range));
259+ createIntExpr (clause.getIntExpr ()), range));
238260 } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
239261 llvm_unreachable (" vector_length not valid on serial" );
240262 } else {
@@ -252,7 +274,7 @@ class OpenACCClauseCIREmitter final
252274 mlir::MutableOperandRange range = operation.getAsyncOperandsMutable ();
253275 operation.setAsyncOperandsDeviceTypeAttr (handleDeviceTypeAffectedClause (
254276 operation.getAsyncOperandsDeviceTypeAttr (),
255- createIntExpr (clause.getIntExpr ()), & range));
277+ createIntExpr (clause.getIntExpr ()), range));
256278 }
257279 } else {
258280 // Data, enter data, exit data, update, wait, combined remain.
@@ -301,6 +323,28 @@ class OpenACCClauseCIREmitter final
301323 }
302324 }
303325
326+ void VisitNumGangsClause (const OpenACCNumGangsClause &clause) {
327+ if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
328+ llvm::SmallVector<mlir::Value> values;
329+
330+ for (const Expr *E : clause.getIntExprs ())
331+ values.push_back (createIntExpr (E));
332+
333+ llvm::SmallVector<int32_t > segments;
334+ if (operation.getNumGangsSegments ())
335+ llvm::copy (*operation.getNumGangsSegments (),
336+ std::back_inserter (segments));
337+
338+ mlir::MutableOperandRange range = operation.getNumGangsMutable ();
339+ operation.setNumGangsDeviceTypeAttr (handleDeviceTypeAffectedClause (
340+ operation.getNumGangsDeviceTypeAttr (), values, range, segments));
341+ operation.setNumGangsSegments (llvm::ArrayRef<int32_t >{segments});
342+ } else {
343+ // combined remains.
344+ return clauseNotImplemented (clause);
345+ }
346+ }
347+
304348 void VisitDefaultAsyncClause (const OpenACCDefaultAsyncClause &clause) {
305349 if constexpr (isOneOfTypes<OpTy, SetOp>) {
306350 operation.getDefaultAsyncMutable ().append (
0 commit comments