Skip to content

Commit e7ee737

Browse files
committed
[OpenACC][CIR] Implement 'device_type' clause lowering for 'init'/'shutdown'
This patch emits the lowering for 'device_type' on an 'init' or 'shutdown'. This one is fairly unique, as these directives have it as an attribute, rather than as a component of the individual operands, like the rest of the constructs. So this patch implements the lowering as an attribute. In order to do tis, a few refactorings had to happen: First, the 'emitOpenACCOp' functions needed to pick up th edirective kind/location so that the NYI diagnostic could be reasonable. Second, and most impactful, the `applyAttributes` function ends up needing to encode some of the appertainment rules, thanks to the way the OpenACC-MLIR operands get their attributes attached. Since they each use a special function (rather than something that can be legalized at runtime), the forms of 'setDefaultAttr' is only valid for some ops. SO this patch uses some `if constexpr` and a small type-trait to help legalize these.
1 parent 9b50167 commit e7ee737

File tree

4 files changed

+135
-22
lines changed

4 files changed

+135
-22
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -585,15 +585,16 @@ class CIRGenFunction : public CIRGenTypeCache {
585585
private:
586586
template <typename Op>
587587
mlir::LogicalResult
588-
emitOpenACCOp(mlir::Location start,
588+
emitOpenACCOp(OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
589+
mlir::Location start,
589590
llvm::ArrayRef<const OpenACCClause *> clauses);
590591
// Function to do the basic implementation of an operation with an Associated
591592
// Statement. Models AssociatedStmtConstruct.
592593
template <typename Op, typename TermOp>
593-
mlir::LogicalResult
594-
emitOpenACCOpAssociatedStmt(mlir::Location start, mlir::Location end,
595-
llvm::ArrayRef<const OpenACCClause *> clauses,
596-
const Stmt *associatedStmt);
594+
mlir::LogicalResult emitOpenACCOpAssociatedStmt(
595+
OpenACCDirectiveKind dirKind, SourceLocation dirLoc, mlir::Location start,
596+
mlir::Location end, llvm::ArrayRef<const OpenACCClause *> clauses,
597+
const Stmt *associatedStmt);
597598

598599
public:
599600
mlir::LogicalResult

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

Lines changed: 103 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// Emit OpenACC Stmt nodes as CIR code.
1010
//
1111
//===----------------------------------------------------------------------===//
12+
#include <type_traits>
1213

1314
#include "CIRGenBuilder.h"
1415
#include "CIRGenFunction.h"
@@ -23,22 +24,39 @@ using namespace cir;
2324
using namespace mlir::acc;
2425

2526
namespace {
27+
// Simple type-trait to see if the first template arg is one of the list, so we
28+
// can tell whether to `if-constexpr` a bunch of stuff.
29+
template <typename ToTest, typename T, typename... Tys>
30+
constexpr bool isOneOfTypes =
31+
std::is_same_v<ToTest, T> || isOneOfTypes<ToTest, Tys...>;
32+
template <typename ToTest, typename T>
33+
constexpr bool isOneOfTypes<ToTest, T> = std::is_same_v<ToTest, T>;
34+
2635
class OpenACCClauseCIREmitter final
2736
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
2837
CIRGenModule &cgm;
38+
// This is necessary since a few of the clauses emit differently based on the
39+
// directive kind they are attached to.
40+
OpenACCDirectiveKind dirKind;
41+
SourceLocation dirLoc;
2942

3043
struct AttributeData {
3144
// Value of the 'default' attribute, added on 'data' and 'compute'/etc
3245
// constructs as a 'default-attr'.
3346
std::optional<ClauseDefaultValue> defaultVal = std::nullopt;
47+
// For directives that have their device type architectures listed in
48+
// attributes (init/shutdown/etc), the list of architectures to be emitted.
49+
llvm::SmallVector<mlir::acc::DeviceType> deviceTypeArchs{};
3450
} attrData;
3551

3652
void clauseNotImplemented(const OpenACCClause &c) {
3753
cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
3854
}
3955

4056
public:
41-
OpenACCClauseCIREmitter(CIRGenModule &cgm) : cgm(cgm) {}
57+
OpenACCClauseCIREmitter(CIRGenModule &cgm, OpenACCDirectiveKind dirKind,
58+
SourceLocation dirLoc)
59+
: cgm(cgm), dirKind(dirKind), dirLoc(dirLoc) {}
4260

4361
void VisitClause(const OpenACCClause &clause) {
4462
clauseNotImplemented(clause);
@@ -57,31 +75,90 @@ class OpenACCClauseCIREmitter final
5775
}
5876
}
5977

78+
mlir::acc::DeviceType decodeDeviceType(const IdentifierInfo *II) {
79+
80+
// '*' case leaves no identifier-info, just a nullptr.
81+
if (!II)
82+
return mlir::acc::DeviceType::Star;
83+
return llvm::StringSwitch<mlir::acc::DeviceType>(II->getName())
84+
.CaseLower("default", mlir::acc::DeviceType::Default)
85+
.CaseLower("host", mlir::acc::DeviceType::Host)
86+
.CaseLower("multicore", mlir::acc::DeviceType::Multicore)
87+
.CasesLower("nvidia", "acc_device_nvidia",
88+
mlir::acc::DeviceType::Nvidia)
89+
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
90+
}
91+
92+
void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
93+
94+
switch (dirKind) {
95+
case OpenACCDirectiveKind::Init:
96+
case OpenACCDirectiveKind::Shutdown: {
97+
// Device type has a list that is either a 'star' (emitted as 'star'),
98+
// or an identifer list, all of which get added for attributes.
99+
100+
for (const DeviceTypeArgument &Arg : clause.getArchitectures())
101+
attrData.deviceTypeArchs.push_back(decodeDeviceType(Arg.first));
102+
break;
103+
}
104+
default:
105+
return clauseNotImplemented(clause);
106+
}
107+
}
108+
60109
// Apply any of the clauses that resulted in an 'attribute'.
61-
template <typename Op> void applyAttributes(Op &op) {
62-
if (attrData.defaultVal.has_value())
63-
op.setDefaultAttr(*attrData.defaultVal);
110+
template <typename Op>
111+
void applyAttributes(CIRGenBuilderTy &builder, Op &op) {
112+
113+
if (attrData.defaultVal.has_value()) {
114+
// FIXME: OpenACC: as we implement this for other directive kinds, we have
115+
// to expand this list.
116+
if constexpr (isOneOfTypes<Op, ParallelOp, SerialOp, KernelsOp, DataOp>)
117+
op.setDefaultAttr(*attrData.defaultVal);
118+
else
119+
cgm.errorNYI(dirLoc, "OpenACC 'default' clause lowering for ", dirKind);
120+
}
121+
122+
if (!attrData.deviceTypeArchs.empty()) {
123+
// FIXME: OpenACC: as we implement this for other directive kinds, we have
124+
// to expand this list, or more likely, have a 'noop' branch as most other
125+
// uses of this apply to the operands instead.
126+
if constexpr (isOneOfTypes<Op, InitOp, ShutdownOp>) {
127+
llvm::SmallVector<mlir::Attribute> deviceTypes;
128+
for (mlir::acc::DeviceType DT : attrData.deviceTypeArchs)
129+
deviceTypes.push_back(
130+
mlir::acc::DeviceTypeAttr::get(builder.getContext(), DT));
131+
132+
op.setDeviceTypesAttr(
133+
mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
134+
} else {
135+
cgm.errorNYI(dirLoc, "OpenACC 'device_type' clause lowering for ",
136+
dirKind);
137+
}
138+
}
64139
}
65140
};
141+
66142
} // namespace
67143

68144
template <typename Op, typename TermOp>
69145
mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
70-
mlir::Location start, mlir::Location end,
71-
llvm::ArrayRef<const OpenACCClause *> clauses, const Stmt *associatedStmt) {
146+
OpenACCDirectiveKind dirKind, SourceLocation dirLoc, mlir::Location start,
147+
mlir::Location end, llvm::ArrayRef<const OpenACCClause *> clauses,
148+
const Stmt *associatedStmt) {
72149
mlir::LogicalResult res = mlir::success();
73150

74151
llvm::SmallVector<mlir::Type> retTy;
75152
llvm::SmallVector<mlir::Value> operands;
76153

77154
// Clause-emitter must be here because it might modify operands.
78-
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
155+
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
79156
clauseEmitter.VisitClauseList(clauses);
80157

81158
auto op = builder.create<Op>(start, retTy, operands);
82159

83160
// Apply the attributes derived from the clauses.
84-
clauseEmitter.applyAttributes(op);
161+
clauseEmitter.applyAttributes(builder, op);
85162

86163
mlir::Block &block = op.getRegion().emplaceBlock();
87164
mlir::OpBuilder::InsertionGuard guardCase(builder);
@@ -96,18 +173,21 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
96173

97174
template <typename Op>
98175
mlir::LogicalResult
99-
CIRGenFunction::emitOpenACCOp(mlir::Location start,
176+
CIRGenFunction::emitOpenACCOp(OpenACCDirectiveKind dirKind,
177+
SourceLocation dirLoc, mlir::Location start,
100178
llvm::ArrayRef<const OpenACCClause *> clauses) {
101179
mlir::LogicalResult res = mlir::success();
102180

103181
llvm::SmallVector<mlir::Type> retTy;
104182
llvm::SmallVector<mlir::Value> operands;
105183

106184
// Clause-emitter must be here because it might modify operands.
107-
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
185+
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
108186
clauseEmitter.VisitClauseList(clauses);
109187

110-
builder.create<Op>(start, retTy, operands);
188+
auto op = builder.create<Op>(start, retTy, operands);
189+
// Apply the attributes derived from the clauses.
190+
clauseEmitter.applyAttributes(builder, op);
111191
return res;
112192
}
113193

@@ -119,13 +199,16 @@ CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
119199
switch (s.getDirectiveKind()) {
120200
case OpenACCDirectiveKind::Parallel:
121201
return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
122-
start, end, s.clauses(), s.getStructuredBlock());
202+
s.getDirectiveKind(), s.getDirectiveLoc(), start, end, s.clauses(),
203+
s.getStructuredBlock());
123204
case OpenACCDirectiveKind::Serial:
124205
return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
125-
start, end, s.clauses(), s.getStructuredBlock());
206+
s.getDirectiveKind(), s.getDirectiveLoc(), start, end, s.clauses(),
207+
s.getStructuredBlock());
126208
case OpenACCDirectiveKind::Kernels:
127209
return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
128-
start, end, s.clauses(), s.getStructuredBlock());
210+
s.getDirectiveKind(), s.getDirectiveLoc(), start, end, s.clauses(),
211+
s.getStructuredBlock());
129212
default:
130213
llvm_unreachable("invalid compute construct kind");
131214
}
@@ -137,18 +220,21 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
137220
mlir::Location end = getLoc(s.getSourceRange().getEnd());
138221

139222
return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
140-
start, end, s.clauses(), s.getStructuredBlock());
223+
s.getDirectiveKind(), s.getDirectiveLoc(), start, end, s.clauses(),
224+
s.getStructuredBlock());
141225
}
142226

143227
mlir::LogicalResult
144228
CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) {
145229
mlir::Location start = getLoc(s.getSourceRange().getEnd());
146-
return emitOpenACCOp<InitOp>(start, s.clauses());
230+
return emitOpenACCOp<InitOp>(s.getDirectiveKind(), s.getDirectiveLoc(), start,
231+
s.clauses());
147232
}
148233
mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(
149234
const OpenACCShutdownConstruct &s) {
150235
mlir::Location start = getLoc(s.getSourceRange().getEnd());
151-
return emitOpenACCOp<ShutdownOp>(start, s.clauses());
236+
return emitOpenACCOp<ShutdownOp>(s.getDirectiveKind(), s.getDirectiveLoc(),
237+
start, s.clauses());
152238
}
153239

154240
mlir::LogicalResult

clang/test/CIR/CodeGenOpenACC/init.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,17 @@ void acc_init(void) {
44
// CHECK: cir.func @acc_init() {
55
#pragma acc init
66
// CHECK-NEXT: acc.init loc(#{{[a-zA-Z0-9]+}}){{$}}
7+
8+
#pragma acc init device_type(*)
9+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<star>]}
10+
#pragma acc init device_type(nvidia)
11+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
12+
#pragma acc init device_type(host, multicore)
13+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
14+
#pragma acc init device_type(NVIDIA)
15+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
16+
#pragma acc init device_type(HoSt, MuLtIcORe)
17+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
18+
#pragma acc init device_type(HoSt) device_type(MuLtIcORe)
19+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
720
}

clang/test/CIR/CodeGenOpenACC/shutdown.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,17 @@ void acc_shutdown(void) {
44
// CHECK: cir.func @acc_shutdown() {
55
#pragma acc shutdown
66
// CHECK-NEXT: acc.shutdown loc(#{{[a-zA-Z0-9]+}}){{$}}
7+
8+
#pragma acc shutdown device_type(*)
9+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<star>]}
10+
#pragma acc shutdown device_type(nvidia)
11+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<nvidia>]}
12+
#pragma acc shutdown device_type(host, multicore)
13+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
14+
#pragma acc shutdown device_type(NVIDIA)
15+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<nvidia>]}
16+
#pragma acc shutdown device_type(HoSt, MuLtIcORe)
17+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
18+
#pragma acc shutdown device_type(HoSt) device_type(MuLtIcORe)
19+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
720
}

0 commit comments

Comments
 (0)