@@ -4414,10 +4414,34 @@ getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
4414
4414
return std::nullopt;
4415
4415
}
4416
4416
4417
+ // Helper function to extract string value from bind name variant
4418
+ static std::optional<llvm::StringRef> getBindNameStringValue (
4419
+ const std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4420
+ &bindNameValue) {
4421
+ if (!bindNameValue.has_value ())
4422
+ return std::nullopt;
4423
+
4424
+ return std::visit (
4425
+ [](const auto &attr) -> std::optional<llvm::StringRef> {
4426
+ if constexpr (std::is_same_v<std::decay_t <decltype (attr)>,
4427
+ mlir::StringAttr>) {
4428
+ return attr.getValue ();
4429
+ } else if constexpr (std::is_same_v<std::decay_t <decltype (attr)>,
4430
+ mlir::SymbolRefAttr>) {
4431
+ return attr.getLeafReference ();
4432
+ } else {
4433
+ return std::nullopt;
4434
+ }
4435
+ },
4436
+ bindNameValue.value ());
4437
+ }
4438
+
4417
4439
static bool compareDeviceTypeInfo (
4418
4440
mlir::acc::RoutineOp op,
4419
- llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
4420
- llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
4441
+ llvm::SmallVector<mlir::Attribute> &bindIdNameArrayAttr,
4442
+ llvm::SmallVector<mlir::Attribute> &bindStrNameArrayAttr,
4443
+ llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypeArrayAttr,
4444
+ llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypeArrayAttr,
4421
4445
llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
4422
4446
llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
4423
4447
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
@@ -4427,9 +4451,13 @@ static bool compareDeviceTypeInfo(
4427
4451
for (uint32_t dtypeInt = 0 ;
4428
4452
dtypeInt != mlir::acc::getMaxEnumValForDeviceType (); ++dtypeInt) {
4429
4453
auto dtype = static_cast <mlir::acc::DeviceType>(dtypeInt);
4430
- if (op.getBindNameValue (dtype) !=
4431
- getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4432
- bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
4454
+ auto bindNameValue = getBindNameStringValue (op.getBindNameValue (dtype));
4455
+ if (bindNameValue !=
4456
+ getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4457
+ bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) &&
4458
+ bindNameValue !=
4459
+ getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4460
+ bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype))
4433
4461
return false ;
4434
4462
if (op.hasGang (dtype) != hasDeviceType (gangArrayAttr, dtype))
4435
4463
return false ;
@@ -4476,8 +4504,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder,
4476
4504
void createOpenACCRoutineConstruct (
4477
4505
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
4478
4506
mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
4479
- bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
4480
- llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4507
+ bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindIdNames,
4508
+ llvm::SmallVector<mlir::Attribute> &bindStrNames,
4509
+ llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4510
+ llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
4481
4511
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
4482
4512
llvm::SmallVector<mlir::Attribute> &gangDimValues,
4483
4513
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
@@ -4490,7 +4520,8 @@ void createOpenACCRoutineConstruct(
4490
4520
0 ) {
4491
4521
// If the routine is already specified with the same clauses, just skip
4492
4522
// the operation creation.
4493
- if (compareDeviceTypeInfo (routineOp, bindNames, bindNameDeviceTypes,
4523
+ if (compareDeviceTypeInfo (routineOp, bindIdNames, bindStrNames,
4524
+ bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4494
4525
gangDeviceTypes, gangDimValues,
4495
4526
gangDimDeviceTypes, seqDeviceTypes,
4496
4527
workerDeviceTypes, vectorDeviceTypes) &&
@@ -4507,8 +4538,10 @@ void createOpenACCRoutineConstruct(
4507
4538
modBuilder.create <mlir::acc::RoutineOp>(
4508
4539
loc, routineOpStr,
4509
4540
mlir::SymbolRefAttr::get (builder.getContext (), funcName),
4510
- getArrayAttrOrNull (builder, bindNames),
4511
- getArrayAttrOrNull (builder, bindNameDeviceTypes),
4541
+ getArrayAttrOrNull (builder, bindIdNames),
4542
+ getArrayAttrOrNull (builder, bindStrNames),
4543
+ getArrayAttrOrNull (builder, bindIdNameDeviceTypes),
4544
+ getArrayAttrOrNull (builder, bindStrNameDeviceTypes),
4512
4545
getArrayAttrOrNull (builder, workerDeviceTypes),
4513
4546
getArrayAttrOrNull (builder, vectorDeviceTypes),
4514
4547
getArrayAttrOrNull (builder, seqDeviceTypes), hasNohost,
@@ -4525,8 +4558,10 @@ static void interpretRoutineDeviceInfo(
4525
4558
llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
4526
4559
llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
4527
4560
llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
4528
- llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4529
- llvm::SmallVector<mlir::Attribute> &bindNames,
4561
+ llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4562
+ llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
4563
+ llvm::SmallVector<mlir::Attribute> &bindIdNames,
4564
+ llvm::SmallVector<mlir::Attribute> &bindStrNames,
4530
4565
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
4531
4566
llvm::SmallVector<mlir::Attribute> &gangDimValues,
4532
4567
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
@@ -4559,16 +4594,18 @@ static void interpretRoutineDeviceInfo(
4559
4594
if (dinfo.bindNameOpt ().has_value ()) {
4560
4595
const auto &bindName = dinfo.bindNameOpt ().value ();
4561
4596
mlir::Attribute bindNameAttr;
4562
- if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
4597
+ if (const auto &bindSym{
4598
+ std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4599
+ bindNameAttr = builder.getSymbolRefAttr (converter.mangleName (*bindSym));
4600
+ bindIdNames.push_back (bindNameAttr);
4601
+ bindIdNameDeviceTypes.push_back (getDeviceTypeAttr ());
4602
+ } else if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
4563
4603
bindNameAttr = builder.getStringAttr (*bindStr);
4564
- } else if (const auto &bindSym{
4565
- std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4566
- bindNameAttr = builder.getStringAttr (converter.mangleName (*bindSym));
4604
+ bindStrNames.push_back (bindNameAttr);
4605
+ bindStrNameDeviceTypes.push_back (getDeviceTypeAttr ());
4567
4606
} else {
4568
4607
llvm_unreachable (" Unsupported bind name type" );
4569
4608
}
4570
- bindNames.push_back (bindNameAttr);
4571
- bindNameDeviceTypes.push_back (getDeviceTypeAttr ());
4572
4609
}
4573
4610
}
4574
4611
@@ -4584,8 +4621,9 @@ void Fortran::lower::genOpenACCRoutineConstruct(
4584
4621
bool hasNohost{false };
4585
4622
4586
4623
llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
4587
- workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4588
- gangDimDeviceTypes, gangDimValues;
4624
+ workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4625
+ bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes,
4626
+ gangDimValues;
4589
4627
4590
4628
for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
4591
4629
// Device Independent Attributes
@@ -4594,24 +4632,26 @@ void Fortran::lower::genOpenACCRoutineConstruct(
4594
4632
}
4595
4633
// Note: Device Independent Attributes are set to the
4596
4634
// none device type in `info`.
4597
- interpretRoutineDeviceInfo (converter, info, seqDeviceTypes,
4598
- vectorDeviceTypes, workerDeviceTypes,
4599
- bindNameDeviceTypes, bindNames, gangDeviceTypes ,
4600
- gangDimValues, gangDimDeviceTypes);
4635
+ interpretRoutineDeviceInfo (
4636
+ converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
4637
+ bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames ,
4638
+ bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes);
4601
4639
4602
4640
// Device Dependent Attributes
4603
4641
for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
4604
4642
info.deviceTypeInfos ()) {
4605
- interpretRoutineDeviceInfo (
4606
- converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
4607
- workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4608
- gangDimValues, gangDimDeviceTypes);
4643
+ interpretRoutineDeviceInfo (converter, dinfo, seqDeviceTypes,
4644
+ vectorDeviceTypes, workerDeviceTypes,
4645
+ bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4646
+ bindIdNames, bindStrNames, gangDeviceTypes,
4647
+ gangDimValues, gangDimDeviceTypes);
4609
4648
}
4610
4649
}
4611
4650
createOpenACCRoutineConstruct (
4612
- converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
4613
- bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
4614
- seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
4651
+ converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames,
4652
+ bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4653
+ gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes,
4654
+ workerDeviceTypes, vectorDeviceTypes);
4615
4655
}
4616
4656
4617
4657
static void
0 commit comments