Skip to content

Commit 78ab389

Browse files
authored
[OpenACC][CIR] Implement 'gang' lowering on `routine' (#170506)
This is a bit more work than the worker/vector/seq in that gang takes an optional `dim` argument. The argument is always 1, 2, or 3 (constants!), and the other argument-types that gang allows elsewhere aren't valid here. For the IR, we had to add 2 overloads of `addGang`. The first just adds the 'valueless' one, which can just add to the one ArrayAttr. The second has to add to TWO lists. Note: The standard limits to only 1 `gang` per construct. We decided after evaluating it, that it really means 'per device-type region'. However, device_type isn't implemented yet, so we'll add tests for that when we do. At the moment, we added the device_type infrastructure however.
1 parent ed6078c commit 78ab389

File tree

4 files changed

+106
-3
lines changed

4 files changed

+106
-3
lines changed

clang/lib/CIR/CodeGen/CIRGenDeclOpenACC.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,16 @@ void CIRGenModule::emitGlobalOpenACCRoutineDecl(const OpenACCRoutineDecl *d) {
303303
namespace {
304304
class OpenACCRoutineClauseEmitter final
305305
: public OpenACCClauseVisitor<OpenACCRoutineClauseEmitter> {
306+
CIRGenModule &cgm;
306307
CIRGen::CIRGenBuilderTy &builder;
307308
mlir::acc::RoutineOp routineOp;
308309
llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;
309310

310311
public:
311-
OpenACCRoutineClauseEmitter(CIRGen::CIRGenBuilderTy &builder,
312+
OpenACCRoutineClauseEmitter(CIRGenModule &cgm,
313+
CIRGen::CIRGenBuilderTy &builder,
312314
mlir::acc::RoutineOp routineOp)
313-
: builder(builder), routineOp(routineOp) {}
315+
: cgm(cgm), builder(builder), routineOp(routineOp) {}
314316

315317
void emitClauses(ArrayRef<const OpenACCClause *> clauses) {
316318
this->VisitClauseList(clauses);
@@ -333,6 +335,26 @@ class OpenACCRoutineClauseEmitter final
333335
void VisitNoHostClause(const OpenACCNoHostClause &clause) {
334336
routineOp.setNohost(/*attrValue=*/true);
335337
}
338+
339+
void VisitGangClause(const OpenACCGangClause &clause) {
340+
// Gang has an optional 'dim' value, which is a constant int of 1, 2, or 3.
341+
// If we don't store any expressions in the clause, there are none, else we
342+
// expect there is 1, since Sema should enforce that the single 'dim' is the
343+
// only valid value.
344+
if (clause.getNumExprs() == 0) {
345+
routineOp.addGang(builder.getContext(), lastDeviceTypeValues);
346+
} else {
347+
assert(clause.getNumExprs() == 1);
348+
auto [kind, expr] = clause.getExpr(0);
349+
assert(kind == OpenACCGangKind::Dim);
350+
351+
llvm::APSInt curValue = expr->EvaluateKnownConstInt(cgm.getASTContext());
352+
// The value is 1, 2, or 3, but 64 bit seems right enough.
353+
curValue = curValue.sextOrTrunc(64);
354+
routineOp.addGang(builder.getContext(), lastDeviceTypeValues,
355+
curValue.getZExtValue());
356+
}
357+
}
336358
};
337359
} // namespace
338360

@@ -373,6 +395,6 @@ void CIRGenModule::emitOpenACCRoutineDecl(
373395
mlir::acc::getRoutineInfoAttrName(),
374396
mlir::acc::RoutineInfoAttr::get(func.getContext(), funcRoutines));
375397

376-
OpenACCRoutineClauseEmitter emitter{builder, routineOp};
398+
OpenACCRoutineClauseEmitter emitter{*this, builder, routineOp};
377399
emitter.emitClauses(clauses);
378400
}

clang/test/CIR/CodeGenOpenACC/routine-clauses.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,27 @@ void Func5() {}
1818
void Func6() {}
1919
#pragma acc routine(Func6) nohost vector
2020

21+
#pragma acc routine gang
22+
void Func7() {}
23+
24+
void Func8() {}
25+
#pragma acc routine(Func8) gang
26+
27+
#pragma acc routine gang(dim:1)
28+
void Func9() {}
29+
30+
void Func10() {}
31+
#pragma acc routine(Func10) gang(dim:3)
32+
33+
constexpr int Value = 2;
34+
35+
#pragma acc routine gang(dim:Value) nohost
36+
void Func11() {}
37+
38+
39+
void Func12() {}
40+
#pragma acc routine(Func12) nohost gang(dim:Value)
41+
2142
// CHECK: cir.func{{.*}} @[[F1_NAME:.*Func1[^\(]*]]({{.*}}){{.*}} attributes {acc.routine_info = #acc.routine_info<[@[[F1_R_NAME:.*]]]>}
2243
// CHECK: acc.routine @[[F1_R_NAME]] func(@[[F1_NAME]]) seq nohost
2344

@@ -32,7 +53,25 @@ void Func6() {}
3253
// CHECK: acc.routine @[[F5_R_NAME]] func(@[[F5_NAME]]) vector
3354

3455
// CHECK: cir.func{{.*}} @[[F6_NAME:.*Func6[^\(]*]]({{.*}}){{.*}} attributes {acc.routine_info = #acc.routine_info<[@[[F6_R_NAME:.*]]]>}
56+
//
57+
// CHECK: cir.func{{.*}} @[[F7_NAME:.*Func7[^\(]*]]({{.*}}){{.*}} attributes {acc.routine_info = #acc.routine_info<[@[[F7_R_NAME:.*]]]>}
58+
// CHECK: acc.routine @[[F7_R_NAME]] func(@[[F7_NAME]]) gang
59+
//
60+
// CHECK: cir.func{{.*}} @[[F8_NAME:.*Func8[^\(]*]]({{.*}}){{.*}} attributes {acc.routine_info = #acc.routine_info<[@[[F8_R_NAME:.*]]]>}
61+
//
62+
// CHECK: cir.func{{.*}} @[[F9_NAME:.*Func9[^\(]*]]({{.*}}){{.*}} attributes {acc.routine_info = #acc.routine_info<[@[[F9_R_NAME:.*]]]>}
63+
// CHECK: acc.routine @[[F9_R_NAME]] func(@[[F9_NAME]]) gang(dim: 1 : i64)
64+
//
65+
// CHECK: cir.func{{.*}} @[[F10_NAME:.*Func10[^\(]*]]({{.*}}){{.*}} attributes {acc.routine_info = #acc.routine_info<[@[[F10_R_NAME:.*]]]>}
66+
67+
// CHECK: cir.func{{.*}} @[[F11_NAME:.*Func11[^\(]*]]({{.*}}){{.*}} attributes {acc.routine_info = #acc.routine_info<[@[[F11_R_NAME:.*]]]>}
68+
// CHECK: acc.routine @[[F11_R_NAME]] func(@[[F11_NAME]]) gang(dim: 2 : i64)
69+
//
70+
// CHECK: cir.func{{.*}} @[[F12_NAME:.*Func12[^\(]*]]({{.*}}){{.*}} attributes {acc.routine_info = #acc.routine_info<[@[[F12_R_NAME:.*]]]>}
3571

3672
// CHECK: acc.routine @[[F2_R_NAME]] func(@[[F2_NAME]]) seq
3773
// CHECK: acc.routine @[[F4_R_NAME]] func(@[[F4_NAME]]) worker nohost
3874
// CHECK: acc.routine @[[F6_R_NAME]] func(@[[F6_NAME]]) vector nohost
75+
// CHECK: acc.routine @[[F8_R_NAME]] func(@[[F8_NAME]]) gang
76+
// CHECK: acc.routine @[[F10_R_NAME]] func(@[[F10_NAME]]) gang(dim: 3 : i64)
77+
// CHECK: acc.routine @[[F12_R_NAME]] func(@[[F12_NAME]]) gang(dim: 2 : i64)

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3300,6 +3300,11 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
33003300
void addVector(MLIRContext *, llvm::ArrayRef<DeviceType>);
33013301
// Add an entry to the 'worker' attribute for each additional device types.
33023302
void addWorker(MLIRContext *, llvm::ArrayRef<DeviceType>);
3303+
// Add an entry to the 'gang' attribute for each additional device type.
3304+
void addGang(MLIRContext *, llvm::ArrayRef<DeviceType>);
3305+
// Add an entry to the 'gang' attribute with a value for each additional
3306+
// device type.
3307+
void addGang(MLIRContext *, llvm::ArrayRef<DeviceType>, uint64_t);
33033308
}];
33043309

33053310
let assemblyFormat = [{

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4388,6 +4388,43 @@ void RoutineOp::addWorker(MLIRContext *context,
43884388
effectiveDeviceTypes));
43894389
}
43904390

4391+
void RoutineOp::addGang(MLIRContext *context,
4392+
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4393+
setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4394+
effectiveDeviceTypes));
4395+
}
4396+
4397+
void RoutineOp::addGang(MLIRContext *context,
4398+
llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4399+
uint64_t val) {
4400+
llvm::SmallVector<mlir::Attribute> dimValues;
4401+
llvm::SmallVector<mlir::Attribute> deviceTypes;
4402+
4403+
if (getGangDimAttr())
4404+
llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4405+
if (getGangDimDeviceTypeAttr())
4406+
llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4407+
4408+
assert(dimValues.size() == deviceTypes.size());
4409+
4410+
if (effectiveDeviceTypes.empty()) {
4411+
dimValues.push_back(
4412+
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4413+
deviceTypes.push_back(
4414+
acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4415+
} else {
4416+
for (DeviceType dt : effectiveDeviceTypes) {
4417+
dimValues.push_back(
4418+
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4419+
deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4420+
}
4421+
}
4422+
assert(dimValues.size() == deviceTypes.size());
4423+
4424+
setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4425+
setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4426+
}
4427+
43914428
//===----------------------------------------------------------------------===//
43924429
// InitOp
43934430
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)