@@ -150,7 +150,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
150
150
builder, loc, registerFuncOp.getArgument (0 ), asFortranDesc, bounds,
151
151
/* structured=*/ false , /* implicit=*/ true ,
152
152
mlir::acc::DataClause::acc_update_device, descTy);
153
- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 0 , 1 };
153
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
154
154
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
155
155
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
156
156
@@ -219,7 +219,7 @@ static void createDeclareDeallocFuncWithArg(
219
219
builder, loc, loadOp, asFortran, bounds,
220
220
/* structured=*/ false , /* implicit=*/ true ,
221
221
mlir::acc::DataClause::acc_update_device, loadOp.getType ());
222
- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 0 , 1 };
222
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
223
223
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
224
224
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
225
225
modBuilder.setInsertionPointAfter (postDeallocOp);
@@ -1416,27 +1416,35 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
1416
1416
}
1417
1417
}
1418
1418
1419
- static void genDeviceTypeClause (
1420
- Fortran::lower::AbstractConverter &converter, mlir::Location clauseLocation,
1419
+ static mlir::acc::DeviceType
1420
+ getDeviceType (Fortran::parser::AccDeviceTypeExpr::Device device) {
1421
+ switch (device) {
1422
+ case Fortran::parser::AccDeviceTypeExpr::Device::Star:
1423
+ return mlir::acc::DeviceType::Star;
1424
+ case Fortran::parser::AccDeviceTypeExpr::Device::Default:
1425
+ return mlir::acc::DeviceType::Default;
1426
+ case Fortran::parser::AccDeviceTypeExpr::Device::Nvidia:
1427
+ return mlir::acc::DeviceType::Nvidia;
1428
+ case Fortran::parser::AccDeviceTypeExpr::Device::Radeon:
1429
+ return mlir::acc::DeviceType::Radeon;
1430
+ case Fortran::parser::AccDeviceTypeExpr::Device::Host:
1431
+ return mlir::acc::DeviceType::Host;
1432
+ case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
1433
+ return mlir::acc::DeviceType::Multicore;
1434
+ }
1435
+ return mlir::acc::DeviceType::Default;
1436
+ }
1437
+
1438
+ static void gatherDeviceTypeAttrs (
1439
+ fir::FirOpBuilder &builder, mlir::Location clauseLocation,
1421
1440
const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
1422
- llvm::SmallVectorImpl <mlir::Value > &operands ,
1441
+ llvm::SmallVector <mlir::Attribute > &deviceTypes ,
1423
1442
Fortran::lower::StatementContext &stmtCtx) {
1424
1443
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
1425
1444
deviceTypeClause->v ;
1426
- for (const auto &deviceTypeExpr : deviceTypeExprList.v ) {
1427
- const auto &expr = std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
1428
- deviceTypeExpr.t );
1429
- if (expr) {
1430
- operands.push_back (fir::getBase (converter.genExprValue (
1431
- *Fortran::semantics::GetExpr (expr), stmtCtx, &clauseLocation)));
1432
- } else {
1433
- // * was passed as value and will be represented as a special constant.
1434
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
1435
- mlir::Value star = firOpBuilder.createIntegerConstant (
1436
- clauseLocation, firOpBuilder.getIndexType (), starCst);
1437
- operands.push_back (star);
1438
- }
1439
- }
1445
+ for (const auto &deviceTypeExpr : deviceTypeExprList.v )
1446
+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
1447
+ builder.getContext (), getDeviceType (deviceTypeExpr.v )));
1440
1448
}
1441
1449
1442
1450
static void genIfClause (Fortran::lower::AbstractConverter &converter,
@@ -2443,10 +2451,10 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
2443
2451
mlir::Location currentLocation,
2444
2452
const Fortran::parser::AccClauseList &accClauseList) {
2445
2453
mlir::Value ifCond, deviceNum;
2446
- llvm::SmallVector<mlir::Value> deviceTypeOperands;
2447
2454
2448
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2455
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
2449
2456
Fortran::lower::StatementContext stmtCtx;
2457
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
2450
2458
2451
2459
// Lower clauses values mapped to operands.
2452
2460
// Keep track of each group of operands separately as clauses can appear
@@ -2464,19 +2472,23 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
2464
2472
} else if (const auto *deviceTypeClause =
2465
2473
std::get_if<Fortran::parser::AccClause::DeviceType>(
2466
2474
&clause.u )) {
2467
- genDeviceTypeClause (converter , clauseLocation, deviceTypeClause,
2468
- deviceTypeOperands , stmtCtx);
2475
+ gatherDeviceTypeAttrs (builder , clauseLocation, deviceTypeClause,
2476
+ deviceTypes , stmtCtx);
2469
2477
}
2470
2478
}
2471
2479
2472
2480
// Prepare the operand segment size attribute and the operands value range.
2473
2481
llvm::SmallVector<mlir::Value, 6 > operands;
2474
- llvm::SmallVector<int32_t , 3 > operandSegments;
2475
- addOperands (operands, operandSegments, deviceTypeOperands);
2482
+ llvm::SmallVector<int32_t , 2 > operandSegments;
2483
+
2476
2484
addOperand (operands, operandSegments, deviceNum);
2477
2485
addOperand (operands, operandSegments, ifCond);
2478
2486
2479
- createSimpleOp<Op>(firOpBuilder, currentLocation, operands, operandSegments);
2487
+ Op op =
2488
+ createSimpleOp<Op>(builder, currentLocation, operands, operandSegments);
2489
+ if (!deviceTypes.empty ())
2490
+ op.setDeviceTypesAttr (
2491
+ mlir::ArrayAttr::get (builder.getContext (), deviceTypes));
2480
2492
}
2481
2493
2482
2494
void genACCSetOp (Fortran::lower::AbstractConverter &converter,
@@ -2485,8 +2497,9 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
2485
2497
mlir::Value ifCond, deviceNum, defaultAsync;
2486
2498
llvm::SmallVector<mlir::Value> deviceTypeOperands;
2487
2499
2488
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2500
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
2489
2501
Fortran::lower::StatementContext stmtCtx;
2502
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
2490
2503
2491
2504
// Lower clauses values mapped to operands.
2492
2505
// Keep track of each group of operands separately as clauses can appear
@@ -2509,21 +2522,24 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
2509
2522
} else if (const auto *deviceTypeClause =
2510
2523
std::get_if<Fortran::parser::AccClause::DeviceType>(
2511
2524
&clause.u )) {
2512
- genDeviceTypeClause (converter , clauseLocation, deviceTypeClause,
2513
- deviceTypeOperands , stmtCtx);
2525
+ gatherDeviceTypeAttrs (builder , clauseLocation, deviceTypeClause,
2526
+ deviceTypes , stmtCtx);
2514
2527
}
2515
2528
}
2516
2529
2517
2530
// Prepare the operand segment size attribute and the operands value range.
2518
2531
llvm::SmallVector<mlir::Value> operands;
2519
- llvm::SmallVector<int32_t , 4 > operandSegments;
2520
- addOperands (operands, operandSegments, deviceTypeOperands);
2532
+ llvm::SmallVector<int32_t , 3 > operandSegments;
2521
2533
addOperand (operands, operandSegments, defaultAsync);
2522
2534
addOperand (operands, operandSegments, deviceNum);
2523
2535
addOperand (operands, operandSegments, ifCond);
2524
2536
2525
- createSimpleOp<mlir::acc::SetOp>(firOpBuilder, currentLocation, operands,
2526
- operandSegments);
2537
+ auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands,
2538
+ operandSegments);
2539
+ if (!deviceTypes.empty ()) {
2540
+ assert (deviceTypes.size () == 1 && " expect only one value for acc.set" );
2541
+ op.setDeviceTypeAttr (mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0 ]));
2542
+ }
2527
2543
}
2528
2544
2529
2545
static void
@@ -2535,6 +2551,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
2535
2551
mlir::Value ifCond, async, waitDevnum;
2536
2552
llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
2537
2553
waitOperands, deviceTypeOperands;
2554
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
2538
2555
2539
2556
// Async and wait clause have optional values but can be present with
2540
2557
// no value as well. When there is no value, the op has an attribute to
@@ -2563,8 +2580,8 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
2563
2580
} else if (const auto *deviceTypeClause =
2564
2581
std::get_if<Fortran::parser::AccClause::DeviceType>(
2565
2582
&clause.u )) {
2566
- genDeviceTypeClause (converter , clauseLocation, deviceTypeClause,
2567
- deviceTypeOperands , stmtCtx);
2583
+ gatherDeviceTypeAttrs (builder , clauseLocation, deviceTypeClause,
2584
+ deviceTypes , stmtCtx);
2568
2585
} else if (const auto *hostClause =
2569
2586
std::get_if<Fortran::parser::AccClause::Host>(&clause.u )) {
2570
2587
genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2602,11 +2619,13 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
2602
2619
addOperand (operands, operandSegments, async);
2603
2620
addOperand (operands, operandSegments, waitDevnum);
2604
2621
addOperands (operands, operandSegments, waitOperands);
2605
- addOperands (operands, operandSegments, deviceTypeOperands);
2606
2622
addOperands (operands, operandSegments, dataClauseOperands);
2607
2623
2608
2624
mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
2609
2625
builder, currentLocation, operands, operandSegments);
2626
+ if (!deviceTypes.empty ())
2627
+ updateOp.setDeviceTypesAttr (
2628
+ mlir::ArrayAttr::get (builder.getContext (), deviceTypes));
2610
2629
2611
2630
genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
2612
2631
builder, updateHostOperands, /* structured=*/ false );
@@ -2787,7 +2806,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
2787
2806
builder, loc, addrOp, asFortranDesc, bounds,
2788
2807
/* structured=*/ false , /* implicit=*/ true ,
2789
2808
mlir::acc::DataClause::acc_update_device, addrOp.getType ());
2790
- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 0 , 1 };
2809
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
2791
2810
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
2792
2811
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
2793
2812
@@ -2863,7 +2882,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
2863
2882
builder, loc, addrOp, asFortran, bounds,
2864
2883
/* structured=*/ false , /* implicit=*/ true ,
2865
2884
mlir::acc::DataClause::acc_update_device, addrOp.getType ());
2866
- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 0 , 1 };
2885
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
2867
2886
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
2868
2887
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
2869
2888
modBuilder.setInsertionPointAfter (postDeallocOp);
0 commit comments