Skip to content

Commit f24d90d

Browse files
committed
[OpenACC][CIR] Implement 'num_gangs' lowering
This is similar to the previous handful of lowering commits, except that it takes an array of int-expressions rather than a single one. This complicates the list of things that need updating (as the 'segments' array also needs updating), which resulted in a bit of a refactor. At the moment, only parallel/kernels are enabled (not parallel loop/kernels loop), so tests are added just for those.
1 parent d859cb6 commit f24d90d

File tree

3 files changed

+186
-22
lines changed

3 files changed

+186
-22
lines changed

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,41 @@ class OpenACCClauseCIREmitter final
9595
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
9696
}
9797

98-
// Handle a clause affected by the 'device-type' to the point that they need
99-
// to have the attributes added in the correct/corresponding order, such as
100-
// 'num_workers' or 'vector_length' on a compute construct. For cases where we
101-
// don't have an expression 'argument' that needs to be added to an operand
102-
// and only care about the 'device-type' list, we can use this with 'argument'
103-
// as 'std::nullopt'. If 'argument' is NOT 'std::nullopt' (that is, has a
104-
// value), argCollection must also be non-null. For cases where we don't have
105-
// an argument that needs to be added to an additional one (such as asyncOnly)
106-
// we can use this with 'argument' as std::nullopt.
107-
mlir::ArrayAttr handleDeviceTypeAffectedClause(
108-
mlir::ArrayAttr existingDeviceTypes,
109-
std::optional<mlir::Value> argument = std::nullopt,
110-
mlir::MutableOperandRange *argCollection = nullptr) {
98+
// Overload of this function that only returns the device-types list.
99+
mlir::ArrayAttr
100+
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes) {
101+
mlir::ValueRange argument;
102+
mlir::MutableOperandRange range{operation};
103+
104+
return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, range);
105+
}
106+
// Overload of this function for when 'segments' aren't necessary.
107+
mlir::ArrayAttr
108+
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
109+
mlir::ValueRange argument,
110+
mlir::MutableOperandRange argCollection) {
111+
llvm::SmallVector<int32_t> segments;
112+
assert(argument.size() <= 1 &&
113+
"Overload only for cases where segments don't need to be added");
114+
return handleDeviceTypeAffectedClause(existingDeviceTypes, argument,
115+
argCollection, segments);
116+
}
117+
118+
// Handle a clause affected by the 'device_type' to the point that they need
119+
// to have attributes added in the correct/corresponding order, such as
120+
// 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
121+
// a collection of operands that need to be appended to the `argCollection` as
122+
// we're adding a 'device_type' entry. If there are more than 0 elements in
123+
// the 'argument', the collection must be non-null, as it is needed to add to
124+
// it.
125+
// As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to
126+
// be maintained, this takes a list of segments that will be updated with the
127+
// proper counts as 'argument' elements are added.
128+
mlir::ArrayAttr
129+
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
130+
mlir::ValueRange argument,
131+
mlir::MutableOperandRange argCollection,
132+
llvm::SmallVector<int32_t> &segments) {
111133
llvm::SmallVector<mlir::Attribute> deviceTypes;
112134

113135
// Collect the 'existing' device-type attributes so we can re-create them
@@ -126,18 +148,18 @@ class OpenACCClauseCIREmitter final
126148
lastDeviceTypeClause->getArchitectures()) {
127149
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
128150
builder.getContext(), decodeDeviceType(arch.getIdentifierInfo())));
129-
if (argument) {
130-
assert(argCollection);
131-
argCollection->append(*argument);
151+
if (!argument.empty()) {
152+
argCollection.append(argument);
153+
segments.push_back(argument.size());
132154
}
133155
}
134156
} else {
135157
// Else, we just add a single entry for 'none'.
136158
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
137159
builder.getContext(), mlir::acc::DeviceType::None));
138-
if (argument) {
139-
assert(argCollection);
140-
argCollection->append(*argument);
160+
if (!argument.empty()) {
161+
argCollection.append(argument);
162+
segments.push_back(argument.size());
141163
}
142164
}
143165

@@ -220,7 +242,7 @@ class OpenACCClauseCIREmitter final
220242
mlir::MutableOperandRange range = operation.getNumWorkersMutable();
221243
operation.setNumWorkersDeviceTypeAttr(handleDeviceTypeAffectedClause(
222244
operation.getNumWorkersDeviceTypeAttr(),
223-
createIntExpr(clause.getIntExpr()), &range));
245+
createIntExpr(clause.getIntExpr()), range));
224246
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
225247
llvm_unreachable("num_workers not valid on serial");
226248
} else {
@@ -234,7 +256,7 @@ class OpenACCClauseCIREmitter final
234256
mlir::MutableOperandRange range = operation.getVectorLengthMutable();
235257
operation.setVectorLengthDeviceTypeAttr(handleDeviceTypeAffectedClause(
236258
operation.getVectorLengthDeviceTypeAttr(),
237-
createIntExpr(clause.getIntExpr()), &range));
259+
createIntExpr(clause.getIntExpr()), range));
238260
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
239261
llvm_unreachable("vector_length not valid on serial");
240262
} else {
@@ -252,7 +274,7 @@ class OpenACCClauseCIREmitter final
252274
mlir::MutableOperandRange range = operation.getAsyncOperandsMutable();
253275
operation.setAsyncOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
254276
operation.getAsyncOperandsDeviceTypeAttr(),
255-
createIntExpr(clause.getIntExpr()), &range));
277+
createIntExpr(clause.getIntExpr()), range));
256278
}
257279
} else {
258280
// Data, enter data, exit data, update, wait, combined remain.
@@ -301,6 +323,28 @@ class OpenACCClauseCIREmitter final
301323
}
302324
}
303325

326+
void VisitNumGangsClause(const OpenACCNumGangsClause &clause) {
327+
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
328+
llvm::SmallVector<mlir::Value> values;
329+
330+
for (const Expr *E : clause.getIntExprs())
331+
values.push_back(createIntExpr(E));
332+
333+
llvm::SmallVector<int32_t> segments;
334+
if (operation.getNumGangsSegments())
335+
llvm::copy(*operation.getNumGangsSegments(),
336+
std::back_inserter(segments));
337+
338+
mlir::MutableOperandRange range = operation.getNumGangsMutable();
339+
operation.setNumGangsDeviceTypeAttr(handleDeviceTypeAffectedClause(
340+
operation.getNumGangsDeviceTypeAttr(), values, range, segments));
341+
operation.setNumGangsSegments(llvm::ArrayRef<int32_t>{segments});
342+
} else {
343+
// combined remains.
344+
return clauseNotImplemented(clause);
345+
}
346+
}
347+
304348
void VisitDefaultAsyncClause(const OpenACCDefaultAsyncClause &clause) {
305349
if constexpr (isOneOfTypes<OpTy, SetOp>) {
306350
operation.getDefaultAsyncMutable().append(

clang/test/CIR/CodeGenOpenACC/kernels.c

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,5 +256,51 @@ void acc_kernels(int cond) {
256256
// CHECK-NEXT: acc.terminator
257257
// CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<nvidia>, #acc.device_type<radeon>]}
258258

259+
#pragma acc kernels num_gangs(1)
260+
{}
261+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
262+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
263+
// CHECK-NEXT: acc.kernels num_gangs({%[[ONE_CAST]] : si32}) {
264+
// CHECK-NEXT: acc.terminator
265+
// CHECK-NEXT: } loc
266+
267+
#pragma acc kernels num_gangs(cond)
268+
{}
269+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
270+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
271+
// CHECK-NEXT: acc.kernels num_gangs({%[[CONV_CAST]] : si32}) {
272+
// CHECK-NEXT: acc.terminator
273+
// CHECK-NEXT: } loc
274+
275+
#pragma acc kernels num_gangs(1) device_type(radeon) num_gangs(cond)
276+
{}
277+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
278+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
279+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
280+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
281+
// CHECK-NEXT: acc.kernels num_gangs({%[[ONE_CAST]] : si32}, {%[[CONV_CAST]] : si32} [#acc.device_type<radeon>]) {
282+
// CHECK-NEXT: acc.terminator
283+
// CHECK-NEXT: } loc
284+
285+
#pragma acc kernels num_gangs(1) device_type(radeon) num_gangs(6)
286+
{}
287+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
288+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
289+
// CHECK-NEXT: %[[SIX_LITERAL:.*]] = cir.const #cir.int<6> : !s32i
290+
// CHECK-NEXT: %[[SIX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SIX_LITERAL]] : !s32i to si32
291+
// CHECK-NEXT: acc.kernels num_gangs({%[[ONE_CAST]] : si32}, {%[[SIX_CAST]] : si32} [#acc.device_type<radeon>]) {
292+
// CHECK-NEXT: acc.terminator
293+
// CHECK-NEXT: } loc
294+
295+
#pragma acc kernels num_gangs(cond) device_type(radeon, nvidia) num_gangs(4)
296+
{}
297+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
298+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
299+
// CHECK-NEXT: %[[FOUR_LITERAL:.*]] = cir.const #cir.int<4> : !s32i
300+
// CHECK-NEXT: %[[FOUR_CAST:.*]] = builtin.unrealized_conversion_cast %[[FOUR_LITERAL]] : !s32i to si32
301+
// CHECK-NEXT: acc.kernels num_gangs({%[[CONV_CAST]] : si32}, {%[[FOUR_CAST]] : si32} [#acc.device_type<radeon>], {%[[FOUR_CAST]] : si32} [#acc.device_type<nvidia>]) {
302+
// CHECK-NEXT: acc.terminator
303+
// CHECK-NEXT: } loc
304+
259305
// CHECK-NEXT: cir.return
260306
}

clang/test/CIR/CodeGenOpenACC/parallel.c

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,5 +255,79 @@ void acc_parallel(int cond) {
255255
// CHECK-NEXT: acc.yield
256256
// CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<nvidia>, #acc.device_type<radeon>]}
257257

258+
#pragma acc parallel num_gangs(1)
259+
{}
260+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
261+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
262+
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32}) {
263+
// CHECK-NEXT: acc.yield
264+
// CHECK-NEXT: } loc
265+
266+
#pragma acc parallel num_gangs(cond)
267+
{}
268+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
269+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
270+
// CHECK-NEXT: acc.parallel num_gangs({%[[CONV_CAST]] : si32}) {
271+
// CHECK-NEXT: acc.yield
272+
// CHECK-NEXT: } loc
273+
274+
#pragma acc parallel num_gangs(1, cond, 2)
275+
{}
276+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
277+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
278+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
279+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
280+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
281+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
282+
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32, %[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32}) {
283+
// CHECK-NEXT: acc.yield
284+
// CHECK-NEXT: } loc
285+
286+
#pragma acc parallel num_gangs(1) device_type(radeon) num_gangs(cond)
287+
{}
288+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
289+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
290+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
291+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
292+
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32}, {%[[CONV_CAST]] : si32} [#acc.device_type<radeon>]) {
293+
// CHECK-NEXT: acc.yield
294+
// CHECK-NEXT: } loc
295+
296+
#pragma acc parallel num_gangs(1, cond, 2) device_type(radeon) num_gangs(4, 5, 6)
297+
{}
298+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
299+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
300+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
301+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
302+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
303+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
304+
// CHECK-NEXT: %[[FOUR_LITERAL:.*]] = cir.const #cir.int<4> : !s32i
305+
// CHECK-NEXT: %[[FOUR_CAST:.*]] = builtin.unrealized_conversion_cast %[[FOUR_LITERAL]] : !s32i to si32
306+
// CHECK-NEXT: %[[FIVE_LITERAL:.*]] = cir.const #cir.int<5> : !s32i
307+
// CHECK-NEXT: %[[FIVE_CAST:.*]] = builtin.unrealized_conversion_cast %[[FIVE_LITERAL]] : !s32i to si32
308+
// CHECK-NEXT: %[[SIX_LITERAL:.*]] = cir.const #cir.int<6> : !s32i
309+
// CHECK-NEXT: %[[SIX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SIX_LITERAL]] : !s32i to si32
310+
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32, %[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32}, {%[[FOUR_CAST]] : si32, %[[FIVE_CAST]] : si32, %[[SIX_CAST]] : si32} [#acc.device_type<radeon>])
311+
// CHECK-NEXT: acc.yield
312+
// CHECK-NEXT: } loc
313+
314+
#pragma acc parallel num_gangs(1, cond, 2) device_type(radeon, nvidia) num_gangs(4, 5, 6)
315+
{}
316+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
317+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
318+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
319+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
320+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
321+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
322+
// CHECK-NEXT: %[[FOUR_LITERAL:.*]] = cir.const #cir.int<4> : !s32i
323+
// CHECK-NEXT: %[[FOUR_CAST:.*]] = builtin.unrealized_conversion_cast %[[FOUR_LITERAL]] : !s32i to si32
324+
// CHECK-NEXT: %[[FIVE_LITERAL:.*]] = cir.const #cir.int<5> : !s32i
325+
// CHECK-NEXT: %[[FIVE_CAST:.*]] = builtin.unrealized_conversion_cast %[[FIVE_LITERAL]] : !s32i to si32
326+
// CHECK-NEXT: %[[SIX_LITERAL:.*]] = cir.const #cir.int<6> : !s32i
327+
// CHECK-NEXT: %[[SIX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SIX_LITERAL]] : !s32i to si32
328+
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32, %[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32}, {%[[FOUR_CAST]] : si32, %[[FIVE_CAST]] : si32, %[[SIX_CAST]] : si32} [#acc.device_type<radeon>], {%[[FOUR_CAST]] : si32, %[[FIVE_CAST]] : si32, %[[SIX_CAST]] : si32} [#acc.device_type<nvidia>])
329+
// CHECK-NEXT: acc.yield
330+
// CHECK-NEXT: } loc
331+
258332
// CHECK-NEXT: cir.return
259333
}

0 commit comments

Comments
 (0)