@@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
171171 builder, loc, registerFuncOp.getArgument (0 ), asFortranDesc, bounds,
172172 /* structured=*/ false , /* implicit=*/ true ,
173173 mlir::acc::DataClause::acc_update_device, descTy);
174- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
174+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 1 };
175175 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
176176 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
177177
@@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
245245 builder, loc, loadOp, asFortran, bounds,
246246 /* structured=*/ false , /* implicit=*/ true ,
247247 mlir::acc::DataClause::acc_update_device, loadOp.getType ());
248- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
248+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 1 };
249249 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
250250 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
251251 modBuilder.setInsertionPointAfter (postDeallocOp);
@@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
15591559 }
15601560}
15611561
1562- static void
1563- genWaitClause ( Fortran::lower::AbstractConverter &converter,
1564- const Fortran::parser::AccClause::Wait *waitClause,
1565- llvm::SmallVector<mlir::Value> &waitOperands,
1566- llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1567- llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1568- llvm::SmallVector<int32_t > &waitOperandsSegments ,
1569- mlir::Value &waitDevnum ,
1570- llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
1571- Fortran::lower::StatementContext &stmtCtx) {
1562+ static void genWaitClauseWithDeviceType (
1563+ Fortran::lower::AbstractConverter &converter,
1564+ const Fortran::parser::AccClause::Wait *waitClause,
1565+ llvm::SmallVector<mlir::Value> &waitOperands,
1566+ llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1567+ llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1568+ llvm::SmallVector<bool > &hasDevnums ,
1569+ llvm::SmallVector< int32_t > &waitOperandsSegments ,
1570+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
1571+ Fortran::lower::StatementContext &stmtCtx) {
15721572 const auto &waitClauseValue = waitClause->v ;
15731573 if (waitClauseValue) { // wait has a value.
1574+ llvm::SmallVector<mlir::Value> waitValues;
1575+
15741576 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1577+ const auto &waitDevnumValue =
1578+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1579+ bool hasDevnum = false ;
1580+ if (waitDevnumValue) {
1581+ waitValues.push_back (fir::getBase (converter.genExprValue (
1582+ *Fortran::semantics::GetExpr (*waitDevnumValue), stmtCtx)));
1583+ hasDevnum = true ;
1584+ }
1585+
15751586 const auto &waitList =
15761587 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1577- llvm::SmallVector<mlir::Value> waitValues;
15781588 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
15791589 waitValues.push_back (fir::getBase (converter.genExprValue (
15801590 *Fortran::semantics::GetExpr (value), stmtCtx)));
15811591 }
1592+
15821593 for (auto deviceTypeAttr : deviceTypeAttrs) {
15831594 for (auto value : waitValues)
15841595 waitOperands.push_back (value);
15851596 waitOperandsDeviceTypes.push_back (deviceTypeAttr);
15861597 waitOperandsSegments.push_back (waitValues.size ());
1598+ hasDevnums.push_back (hasDevnum);
15871599 }
1588-
1589- // TODO: move to device_type model.
1590- const auto &waitDevnumValue =
1591- std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1592- if (waitDevnumValue)
1593- waitDevnum = fir::getBase (converter.genExprValue (
1594- *Fortran::semantics::GetExpr (*waitDevnumValue), stmtCtx));
15951600 } else {
15961601 for (auto deviceTypeAttr : deviceTypeAttrs)
15971602 waitOnlyDeviceTypes.push_back (deviceTypeAttr);
@@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20932098 vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
20942099 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
20952100 llvm::SmallVector<int32_t > numGangsSegments, waitOperandsSegments;
2101+ llvm::SmallVector<bool > hasWaitDevnums;
20962102
20972103 llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
20982104 firstprivateOperands;
20992105 llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
21002106 reductionRecipes;
2101- mlir::Value waitDevnum; // TODO not yet implemented on compute op.
21022107
21032108 // Self clause has optional values but can be present with
21042109 // no value as well. When there is no value, the op has an attribute to
@@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
21282133 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
21292134 } else if (const auto *waitClause =
21302135 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
2131- genWaitClause (converter, waitClause, waitOperands,
2132- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2133- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
2136+ genWaitClauseWithDeviceType (converter, waitClause, waitOperands,
2137+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2138+ hasWaitDevnums, waitOperandsSegments,
2139+ crtDeviceTypes, stmtCtx);
21342140 } else if (const auto *numGangsClause =
21352141 std::get_if<Fortran::parser::AccClause::NumGangs>(
21362142 &clause.u )) {
@@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
23722378 builder.getDenseI32ArrayAttr (numGangsSegments));
23732379 }
23742380 if (!asyncDeviceTypes.empty ())
2375- computeOp.setAsyncDeviceTypeAttr (builder.getArrayAttr (asyncDeviceTypes));
2381+ computeOp.setAsyncOperandsDeviceTypeAttr (
2382+ builder.getArrayAttr (asyncDeviceTypes));
23762383 if (!asyncOnlyDeviceTypes.empty ())
23772384 computeOp.setAsyncOnlyAttr (builder.getArrayAttr (asyncOnlyDeviceTypes));
23782385
@@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
23822389 if (!waitOperandsSegments.empty ())
23832390 computeOp.setWaitOperandsSegmentsAttr (
23842391 builder.getDenseI32ArrayAttr (waitOperandsSegments));
2392+ if (!hasWaitDevnums.empty ())
2393+ computeOp.setHasWaitDevnumAttr (builder.getBoolArrayAttr (hasWaitDevnums));
23852394 if (!waitOnlyDeviceTypes.empty ())
23862395 computeOp.setWaitOnlyAttr (builder.getArrayAttr (waitOnlyDeviceTypes));
23872396
@@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
24272436 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
24282437 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
24292438 llvm::SmallVector<int32_t > waitOperandsSegments;
2439+ llvm::SmallVector<bool > hasWaitDevnums;
24302440
24312441 bool hasDefaultNone = false ;
24322442 bool hasDefaultPresent = false ;
@@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
25232533 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
25242534 } else if (const auto *waitClause =
25252535 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
2526- genWaitClause (converter, waitClause, waitOperands,
2527- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2528- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
2536+ genWaitClauseWithDeviceType (converter, waitClause, waitOperands,
2537+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2538+ hasWaitDevnums, waitOperandsSegments,
2539+ crtDeviceTypes, stmtCtx);
25292540 } else if (const auto *defaultClause =
25302541 std::get_if<Fortran::parser::AccClause::Default>(&clause.u )) {
25312542 if ((defaultClause->v ).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
25452556 llvm::SmallVector<int32_t > operandSegments;
25462557 addOperand (operands, operandSegments, ifCond);
25472558 addOperands (operands, operandSegments, async);
2548- addOperand (operands, operandSegments, waitDevnum);
25492559 addOperands (operands, operandSegments, waitOperands);
25502560 addOperands (operands, operandSegments, dataClauseOperands);
25512561
@@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
25572567 operandSegments);
25582568
25592569 if (!asyncDeviceTypes.empty ())
2560- dataOp.setAsyncDeviceTypeAttr (builder.getArrayAttr (asyncDeviceTypes));
2570+ dataOp.setAsyncOperandsDeviceTypeAttr (
2571+ builder.getArrayAttr (asyncDeviceTypes));
25612572 if (!asyncOnlyDeviceTypes.empty ())
25622573 dataOp.setAsyncOnlyAttr (builder.getArrayAttr (asyncOnlyDeviceTypes));
25632574 if (!waitOperandsDeviceTypes.empty ())
@@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
25662577 if (!waitOperandsSegments.empty ())
25672578 dataOp.setWaitOperandsSegmentsAttr (
25682579 builder.getDenseI32ArrayAttr (waitOperandsSegments));
2580+ if (!hasWaitDevnums.empty ())
2581+ dataOp.setHasWaitDevnumAttr (builder.getBoolArrayAttr (hasWaitDevnums));
25692582 if (!waitOnlyDeviceTypes.empty ())
25702583 dataOp.setWaitOnlyAttr (builder.getArrayAttr (waitOnlyDeviceTypes));
25712584
@@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
30073020 return attributes.empty () ? nullptr : b.getArrayAttr (attributes);
30083021}
30093022
3023+ static inline mlir::ArrayAttr
3024+ getBoolArrayAttr (fir::FirOpBuilder &b, llvm::SmallVector<bool > &values) {
3025+ return values.empty () ? nullptr : b.getBoolArrayAttr (values);
3026+ }
3027+
30103028static inline mlir::DenseI32ArrayAttr
30113029getDenseI32ArrayAttr (fir::FirOpBuilder &builder,
30123030 llvm::SmallVector<int32_t > &values) {
@@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
30243042 waitOperands, deviceTypeOperands, asyncOperands;
30253043 llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
30263044 asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
3045+ llvm::SmallVector<bool > hasWaitDevnums;
30273046 llvm::SmallVector<int32_t > waitOperandsSegments;
30283047
30293048 fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
@@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
30513070 crtDeviceTypes, stmtCtx);
30523071 } else if (const auto *waitClause =
30533072 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
3054- genWaitClause (converter, waitClause, waitOperands,
3055- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
3056- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
3073+ genWaitClauseWithDeviceType (converter, waitClause, waitOperands,
3074+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
3075+ hasWaitDevnums, waitOperandsSegments,
3076+ crtDeviceTypes, stmtCtx);
30573077 } else if (const auto *deviceTypeClause =
30583078 std::get_if<Fortran::parser::AccClause::DeviceType>(
30593079 &clause.u )) {
@@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
30923112 builder.create <mlir::acc::UpdateOp>(
30933113 currentLocation, ifCond, asyncOperands,
30943114 getArrayAttr (builder, asyncOperandsDeviceTypes),
3095- getArrayAttr (builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
3115+ getArrayAttr (builder, asyncOnlyDeviceTypes), waitOperands,
30963116 getDenseI32ArrayAttr (builder, waitOperandsSegments),
30973117 getArrayAttr (builder, waitOperandsDeviceTypes),
3118+ getBoolArrayAttr (builder, hasWaitDevnums),
30983119 getArrayAttr (builder, waitOnlyDeviceTypes), dataClauseOperands,
30993120 ifPresent);
31003121
@@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
32683289 builder, loc, addrOp, asFortranDesc, bounds,
32693290 /* structured=*/ false , /* implicit=*/ true ,
32703291 mlir::acc::DataClause::acc_update_device, addrOp.getType ());
3271- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
3292+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 1 };
32723293 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
32733294 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
32743295
@@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
33493370 builder, loc, addrOp, asFortran, bounds,
33503371 /* structured=*/ false , /* implicit=*/ true ,
33513372 mlir::acc::DataClause::acc_update_device, addrOp.getType ());
3352- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
3373+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 1 };
33533374 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
33543375 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
33553376 modBuilder.setInsertionPointAfter (postDeallocOp);
0 commit comments