@@ -46,10 +46,27 @@ class OpenACCClauseCIREmitter final
4646 // diagnostics are gone.
4747 SourceLocation dirLoc;
4848
// The most recently visited 'device_type' clause, or nullptr if none has been
// visited yet. VisitDeviceTypeClause records it, and VisitNumWorkersClause
// reads it to decide which device types its operand is attached to.
49+ const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr ;
50+
4951 void clauseNotImplemented (const OpenACCClause &c) {
5052 cgf.cgm .errorNYI (c.getSourceRange (), " OpenACC Clause" , c.getClauseKind ());
5153 }
5254
55+ mlir::Value createIntExpr (const Expr *intExpr) {
56+ mlir::Value expr = cgf.emitScalarExpr (intExpr);
57+ mlir::Location exprLoc = cgf.cgm .getLoc (intExpr->getBeginLoc ());
58+
59+ mlir::IntegerType targetType = mlir::IntegerType::get (
60+ &cgf.getMLIRContext (), cgf.getContext ().getIntWidth (intExpr->getType ()),
61+ intExpr->getType ()->isSignedIntegerOrEnumerationType ()
62+ ? mlir::IntegerType::SignednessSemantics::Signed
63+ : mlir::IntegerType::SignednessSemantics::Unsigned);
64+
65+ auto conversionOp = builder.create <mlir::UnrealizedConversionCastOp>(
66+ exprLoc, targetType, expr);
67+ return conversionOp.getResult (0 );
68+ }
69+
5370 // 'condition' as an OpenACC grammar production is used for 'if' and (some
5471 // variants of) 'self'. It needs to be emitted as a signless-1-bit value, so
5572 // this function emits the expression, then sets the unrealized conversion
@@ -109,14 +126,15 @@ class OpenACCClauseCIREmitter final
109126 }
110127
111128 void VisitDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
129+ lastDeviceTypeClause = &clause;
112130 if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
113131 llvm::SmallVector<mlir::Attribute> deviceTypes;
114132 std::optional<mlir::ArrayAttr> existingDeviceTypes =
115133 operation.getDeviceTypes ();
116134
117135 // Ensure we keep the existing ones, and in the correct 'new' order.
118136 if (existingDeviceTypes) {
119- for (const mlir::Attribute & Attr : *existingDeviceTypes)
137+ for (mlir::Attribute Attr : *existingDeviceTypes)
120138 deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
121139 builder.getContext (),
122140 cast<mlir::acc::DeviceTypeAttr>(Attr).getValue ()));
@@ -136,6 +154,51 @@ class OpenACCClauseCIREmitter final
136154 if (!clause.getArchitectures ().empty ())
137155 operation.setDeviceType (
138156 decodeDeviceType (clause.getArchitectures ()[0 ].getIdentifierInfo ()));
157+ } else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp>) {
158+ // Nothing to do here: these constructs have no IR of their own for this
159+ // clause, as it only modifies the IR of the other clauses. So setting
160+ // `lastDeviceTypeClause` (done above) is all we need.
161+ } else {
162+ return clauseNotImplemented (clause);
163+ }
164+ }
165+
166+ void VisitNumWorkersClause (const OpenACCNumWorkersClause &clause) {
167+ if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
168+ // Collect the 'existing' device-type attributes so we can re-create them
169+ // and insert them.
170+ llvm::SmallVector<mlir::Attribute> deviceTypes;
171+ mlir::ArrayAttr existingDeviceTypes =
172+ operation.getNumWorkersDeviceTypeAttr ();
173+
174+ if (existingDeviceTypes) {
175+ for (mlir::Attribute Attr : existingDeviceTypes)
176+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
177+ builder.getContext (),
178+ cast<mlir::acc::DeviceTypeAttr>(Attr).getValue ()));
179+ }
180+
181+ // Insert 1 version of the 'int-expr' to the NumWorkers list per-current
182+ // device type.
183+ mlir::Value intExpr = createIntExpr (clause.getIntExpr ());
184+ if (lastDeviceTypeClause) {
185+ for (const DeviceTypeArgument &arg :
186+ lastDeviceTypeClause->getArchitectures ()) {
187+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
188+ builder.getContext (), decodeDeviceType (arg.getIdentifierInfo ())));
189+ operation.getNumWorkersMutable ().append (intExpr);
190+ }
191+ } else {
192+ // Otherwise, we just add a single entry for 'none'.
193+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
194+ builder.getContext (), mlir::acc::DeviceType::None));
195+ operation.getNumWorkersMutable ().append (intExpr);
196+ }
197+
198+ operation.setNumWorkersDeviceTypeAttr (
199+ mlir::ArrayAttr::get (builder.getContext (), deviceTypes));
200+ } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
201+ llvm_unreachable (" num_workers not valid on serial" );
139202 } else {
140203 return clauseNotImplemented (clause);
141204 }
0 commit comments