@@ -317,10 +317,18 @@ class OpenACCClauseCIREmitter final
317317 operation.getAsyncOperandsDeviceTypeAttr (),
318318 createIntExpr (clause.getIntExpr ()), range));
319319 }
320+ } else if constexpr (isOneOfTypes<OpTy, WaitOp>) {
321+ // Wait doesn't have a device_type, so its handling here is slightly
322+ // different.
323+ if (!clause.hasIntExpr ())
324+ operation.setAsync (true );
325+ else
326+ operation.getAsyncOperandMutable ().append (
327+ createIntExpr (clause.getIntExpr ()));
320328 } else {
321329 // TODO: When we've implemented this for everything, switch this to an
322330 // unreachable. Combined constructs remain. Data, enter data, exit data,
323- // update, wait, combined constructs remain.
331+ // update, combined constructs remain.
324332 return clauseNotImplemented (clause);
325333 }
326334 }
@@ -345,15 +353,15 @@ class OpenACCClauseCIREmitter final
345353
346354 void VisitIfClause (const OpenACCIfClause &clause) {
347355 if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, InitOp,
348- ShutdownOp, SetOp, DataOp>) {
356+ ShutdownOp, SetOp, DataOp, WaitOp >) {
349357 operation.getIfCondMutable ().append (
350358 createCondition (clause.getConditionExpr ()));
351359 } else {
352360 // 'if' applies to most of the constructs, but hold off on lowering them
353361 // until we can write tests/know what we're doing with codegen to make
354362 // sure we get it right.
355363 // TODO: When we've implemented this for everything, switch this to an
356- // unreachable. Enter data, exit data, host_data, update, wait, combined
364+ // unreachable. Enter data, exit data, host_data, update, combined
357365 // constructs remain.
358366 return clauseNotImplemented (clause);
359367 }
@@ -444,11 +452,9 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
444452}
445453
446454template <typename Op>
447- mlir::LogicalResult CIRGenFunction::emitOpenACCOp (
455+ Op CIRGenFunction::emitOpenACCOp (
448456 mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
449457 llvm::ArrayRef<const OpenACCClause *> clauses) {
450- mlir::LogicalResult res = mlir::success ();
451-
452458 llvm::SmallVector<mlir::Type> retTy;
453459 llvm::SmallVector<mlir::Value> operands;
454460 auto op = builder.create <Op>(start, retTy, operands);
@@ -461,7 +467,7 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOp(
461467 makeClauseEmitter (op, *this , builder, dirKind, dirLoc)
462468 .VisitClauseList (clauses);
463469 }
464- return res ;
470+ return op ;
465471}
466472
467473mlir::LogicalResult
@@ -500,22 +506,61 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
500506mlir::LogicalResult
501507CIRGenFunction::emitOpenACCInitConstruct (const OpenACCInitConstruct &s) {
502508 mlir::Location start = getLoc (s.getSourceRange ().getBegin ());
503- return emitOpenACCOp<InitOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
509+ emitOpenACCOp<InitOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
504510 s.clauses ());
511+ return mlir::success ();
505512}
506513
507514mlir::LogicalResult
508515CIRGenFunction::emitOpenACCSetConstruct (const OpenACCSetConstruct &s) {
509516 mlir::Location start = getLoc (s.getSourceRange ().getBegin ());
510- return emitOpenACCOp<SetOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
517+ emitOpenACCOp<SetOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
511518 s.clauses ());
519+ return mlir::success ();
512520}
513521
514522mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct (
515523 const OpenACCShutdownConstruct &s) {
516524 mlir::Location start = getLoc (s.getSourceRange ().getBegin ());
517- return emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind (),
525+ emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind (),
518526 s.getDirectiveLoc (), s.clauses ());
527+ return mlir::success ();
528+ }
529+
530+ mlir::LogicalResult
531+ CIRGenFunction::emitOpenACCWaitConstruct (const OpenACCWaitConstruct &s) {
532+ mlir::Location start = getLoc (s.getSourceRange ().getBegin ());
533+ auto waitOp = emitOpenACCOp<WaitOp>(start, s.getDirectiveKind (),
534+ s.getDirectiveLoc (), s.clauses ());
535+
536+ auto createIntExpr = [this ](const Expr *intExpr) {
537+ mlir::Value expr = emitScalarExpr (intExpr);
538+ mlir::Location exprLoc = cgm.getLoc (intExpr->getBeginLoc ());
539+
540+ mlir::IntegerType targetType = mlir::IntegerType::get (
541+ &getMLIRContext (), getContext ().getIntWidth (intExpr->getType ()),
542+ intExpr->getType ()->isSignedIntegerOrEnumerationType ()
543+ ? mlir::IntegerType::SignednessSemantics::Signed
544+ : mlir::IntegerType::SignednessSemantics::Unsigned);
545+
546+ auto conversionOp = builder.create <mlir::UnrealizedConversionCastOp>(
547+ exprLoc, targetType, expr);
548+ return conversionOp.getResult (0 );
549+ };
550+
551+ // Emit the correct 'wait' clauses.
552+ {
553+ mlir::OpBuilder::InsertionGuard guardCase (builder);
554+ builder.setInsertionPoint (waitOp);
555+
556+ if (s.hasDevNumExpr ())
557+ waitOp.getWaitDevnumMutable ().append (createIntExpr (s.getDevNumExpr ()));
558+
559+ for (Expr *QueueExpr : s.getQueueIdExprs ())
560+ waitOp.getWaitOperandsMutable ().append (createIntExpr (QueueExpr));
561+ }
562+
563+ return mlir::success ();
519564}
520565
521566mlir::LogicalResult
@@ -544,11 +589,6 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct(
544589 return mlir::failure ();
545590}
546591mlir::LogicalResult
547- CIRGenFunction::emitOpenACCWaitConstruct (const OpenACCWaitConstruct &s) {
548- cgm.errorNYI (s.getSourceRange (), " OpenACC Wait Construct" );
549- return mlir::failure ();
550- }
551- mlir::LogicalResult
552592CIRGenFunction::emitOpenACCUpdateConstruct (const OpenACCUpdateConstruct &s) {
553593 cgm.errorNYI (s.getSourceRange (), " OpenACC Update Construct" );
554594 return mlir::failure ();
0 commit comments