@@ -46,10 +46,27 @@ class OpenACCClauseCIREmitter final
4646 // diagnostics are gone.
4747 SourceLocation dirLoc;
4848
49+ const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr ;
50+
4951 void clauseNotImplemented (const OpenACCClause &c) {
5052 cgf.cgm .errorNYI (c.getSourceRange (), " OpenACC Clause" , c.getClauseKind ());
5153 }
5254
55+ mlir::Value createIntExpr (const Expr *intExpr) {
56+ mlir::Value expr = cgf.emitScalarExpr (intExpr);
57+ mlir::Location exprLoc = cgf.cgm .getLoc (intExpr->getBeginLoc ());
58+
59+ mlir::IntegerType targetType = mlir::IntegerType::get (
60+ &cgf.getMLIRContext (), cgf.getContext ().getIntWidth (intExpr->getType ()),
61+ intExpr->getType ()->isSignedIntegerOrEnumerationType ()
62+ ? mlir::IntegerType::SignednessSemantics::Signed
63+ : mlir::IntegerType::SignednessSemantics::Unsigned);
64+
65+ auto conversionOp = builder.create <mlir::UnrealizedConversionCastOp>(
66+ exprLoc, targetType, expr);
67+ return conversionOp.getResult (0 );
68+ }
69+
5370 // 'condition' as an OpenACC grammar production is used for 'if' and (some
5471 // variants of) 'self'. It needs to be emitted as a signless-1-bit value, so
5572 // this function emits the expression, then sets the unrealized conversion
@@ -65,6 +82,56 @@ class OpenACCClauseCIREmitter final
6582 return conversionOp.getResult (0 );
6683 }
6784
85+ mlir::acc::DeviceType decodeDeviceType (const IdentifierInfo *ii) {
86+ // '*' case leaves no identifier-info, just a nullptr.
87+ if (!ii)
88+ return mlir::acc::DeviceType::Star;
89+ return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName ())
90+ .CaseLower (" default" , mlir::acc::DeviceType::Default)
91+ .CaseLower (" host" , mlir::acc::DeviceType::Host)
92+ .CaseLower (" multicore" , mlir::acc::DeviceType::Multicore)
93+ .CasesLower (" nvidia" , " acc_device_nvidia" ,
94+ mlir::acc::DeviceType::Nvidia)
95+ .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
96+ }
97+
98+ // Handle a clause affected by the 'device-type' to the point that they need
99+ // to have the attributes added in the correct/corresponding order, such as
100+ // 'num_workers' or 'vector_length' on a compute construct.
101+ mlir::ArrayAttr
102+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
103+ mlir::Value argument,
104+ mlir::MutableOperandRange &argCollection) {
105+ llvm::SmallVector<mlir::Attribute> deviceTypes;
106+
107+ // Collect the 'existing' device-type attributes so we can re-create them
108+ // and insert them.
109+ if (existingDeviceTypes) {
110+ for (const mlir::Attribute &Attr : existingDeviceTypes)
111+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
112+ builder.getContext (),
113+ cast<mlir::acc::DeviceTypeAttr>(Attr).getValue ()));
114+ }
115+
116+ // Insert 1 version of the 'expr' to the NumWorkers list per-current
117+ // device type.
118+ if (lastDeviceTypeClause) {
119+ for (const DeviceTypeArgument &arch :
120+ lastDeviceTypeClause->getArchitectures ()) {
121+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
122+ builder.getContext (), decodeDeviceType (arch.getIdentifierInfo ())));
123+ argCollection.append (argument);
124+ }
125+ } else {
126+ // Else, we just add a single for 'none'.
127+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
128+ builder.getContext (), mlir::acc::DeviceType::None));
129+ argCollection.append (argument);
130+ }
131+
132+ return mlir::ArrayAttr::get (builder.getContext (), deviceTypes);
133+ }
134+
68135public:
69136 OpenACCClauseCIREmitter (OpTy &operation, CIRGenFunction &cgf,
70137 CIRGenBuilderTy &builder,
@@ -95,31 +162,19 @@ class OpenACCClauseCIREmitter final
95162 }
96163 }
97164
98- mlir::acc::DeviceType decodeDeviceType (const IdentifierInfo *ii) {
99- // '*' case leaves no identifier-info, just a nullptr.
100- if (!ii)
101- return mlir::acc::DeviceType::Star;
102- return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName ())
103- .CaseLower (" default" , mlir::acc::DeviceType::Default)
104- .CaseLower (" host" , mlir::acc::DeviceType::Host)
105- .CaseLower (" multicore" , mlir::acc::DeviceType::Multicore)
106- .CasesLower (" nvidia" , " acc_device_nvidia" ,
107- mlir::acc::DeviceType::Nvidia)
108- .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
109- }
110-
111165 void VisitDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
166+ lastDeviceTypeClause = &clause;
112167 if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
113168 llvm::SmallVector<mlir::Attribute> deviceTypes;
114169 std::optional<mlir::ArrayAttr> existingDeviceTypes =
115170 operation.getDeviceTypes ();
116171
117172 // Ensure we keep the existing ones, and in the correct 'new' order.
118173 if (existingDeviceTypes) {
119- for (const mlir::Attribute &Attr : *existingDeviceTypes)
174+ for (mlir::Attribute attr : *existingDeviceTypes)
120175 deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
121176 builder.getContext (),
122- cast<mlir::acc::DeviceTypeAttr>(Attr ).getValue ()));
177+ cast<mlir::acc::DeviceTypeAttr>(attr ).getValue ()));
123178 }
124179
125180 for (const DeviceTypeArgument &arg : clause.getArchitectures ()) {
@@ -136,6 +191,36 @@ class OpenACCClauseCIREmitter final
136191 if (!clause.getArchitectures ().empty ())
137192 operation.setDeviceType (
138193 decodeDeviceType (clause.getArchitectures ()[0 ].getIdentifierInfo ()));
194+ } else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp>) {
195+ // Nothing to do here, these constructs don't have any IR for these, as
196+ // they just modify the other clauses IR. So setting of `lastDeviceType`
197+ // (done above) is all we need.
198+ } else {
199+ return clauseNotImplemented (clause);
200+ }
201+ }
202+
203+ void VisitNumWorkersClause (const OpenACCNumWorkersClause &clause) {
204+ if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
205+ mlir::MutableOperandRange range = operation.getNumWorkersMutable ();
206+ operation.setNumWorkersDeviceTypeAttr (handleDeviceTypeAffectedClause (
207+ operation.getNumWorkersDeviceTypeAttr (),
208+ createIntExpr (clause.getIntExpr ()), range));
209+ } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
210+ llvm_unreachable (" num_workers not valid on serial" );
211+ } else {
212+ return clauseNotImplemented (clause);
213+ }
214+ }
215+
216+ void VisitVectorLengthClause (const OpenACCVectorLengthClause &clause) {
217+ if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
218+ mlir::MutableOperandRange range = operation.getVectorLengthMutable ();
219+ operation.setVectorLengthDeviceTypeAttr (handleDeviceTypeAffectedClause (
220+ operation.getVectorLengthDeviceTypeAttr (),
221+ createIntExpr (clause.getIntExpr ()), range));
222+ } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
223+ llvm_unreachable (" vector_length not valid on serial" );
139224 } else {
140225 return clauseNotImplemented (clause);
141226 }
0 commit comments