@@ -4396,10 +4396,35 @@ getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
43964396 return std::nullopt ;
43974397}
43984398
4399+ // Helper function to extract string value from bind name variant
4400+ static std::optional<llvm::StringRef> getBindNameStringValue (
4401+ const std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4402+ &bindNameValue) {
4403+ if (!bindNameValue.has_value ()) {
4404+ return std::nullopt ;
4405+ }
4406+
4407+ return std::visit (
4408+ [](const auto &attr) -> std::optional<llvm::StringRef> {
4409+ if constexpr (std::is_same_v<std::decay_t <decltype (attr)>,
4410+ mlir::StringAttr>) {
4411+ return attr.getValue ();
4412+ } else if constexpr (std::is_same_v<std::decay_t <decltype (attr)>,
4413+ mlir::SymbolRefAttr>) {
4414+ return attr.getLeafReference ();
4415+ } else {
4416+ return std::nullopt ;
4417+ }
4418+ },
4419+ bindNameValue.value ());
4420+ }
4421+
43994422static bool compareDeviceTypeInfo (
44004423 mlir::acc::RoutineOp op,
4401- llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
4402- llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
4424+ llvm::SmallVector<mlir::Attribute> &bindIdNameArrayAttr,
4425+ llvm::SmallVector<mlir::Attribute> &bindStrNameArrayAttr,
4426+ llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypeArrayAttr,
4427+ llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypeArrayAttr,
44034428 llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
44044429 llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
44054430 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
@@ -4409,9 +4434,13 @@ static bool compareDeviceTypeInfo(
44094434 for (uint32_t dtypeInt = 0 ;
44104435 dtypeInt != mlir::acc::getMaxEnumValForDeviceType (); ++dtypeInt) {
44114436 auto dtype = static_cast <mlir::acc::DeviceType>(dtypeInt);
4412- if (op.getBindNameValue (dtype) !=
4413- getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4414- bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
4437+ auto bindNameValue = getBindNameStringValue (op.getBindNameValue (dtype));
4438+ if (bindNameValue !=
4439+ getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4440+ bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) &&
4441+ bindNameValue !=
4442+ getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4443+ bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype))
44154444 return false ;
44164445 if (op.hasGang (dtype) != hasDeviceType (gangArrayAttr, dtype))
44174446 return false ;
@@ -4458,8 +4487,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder,
44584487void createOpenACCRoutineConstruct (
44594488 Fortran::lower::AbstractConverter &converter, mlir::Location loc,
44604489 mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
4461- bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
4462- llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4490+ bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindIdNames,
4491+ llvm::SmallVector<mlir::Attribute> &bindStrNames,
4492+ llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4493+ llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
44634494 llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
44644495 llvm::SmallVector<mlir::Attribute> &gangDimValues,
44654496 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
@@ -4472,7 +4503,8 @@ void createOpenACCRoutineConstruct(
44724503 0 ) {
44734504 // If the routine is already specified with the same clauses, just skip
44744505 // the operation creation.
4475- if (compareDeviceTypeInfo (routineOp, bindNames, bindNameDeviceTypes,
4506+ if (compareDeviceTypeInfo (routineOp, bindIdNames, bindStrNames,
4507+ bindIdNameDeviceTypes, bindStrNameDeviceTypes,
44764508 gangDeviceTypes, gangDimValues,
44774509 gangDimDeviceTypes, seqDeviceTypes,
44784510 workerDeviceTypes, vectorDeviceTypes) &&
@@ -4489,8 +4521,10 @@ void createOpenACCRoutineConstruct(
44894521 modBuilder.create <mlir::acc::RoutineOp>(
44904522 loc, routineOpStr,
44914523 mlir::SymbolRefAttr::get (builder.getContext (), funcName),
4492- getArrayAttrOrNull (builder, bindNames),
4493- getArrayAttrOrNull (builder, bindNameDeviceTypes),
4524+ getArrayAttrOrNull (builder, bindIdNames),
4525+ getArrayAttrOrNull (builder, bindStrNames),
4526+ getArrayAttrOrNull (builder, bindIdNameDeviceTypes),
4527+ getArrayAttrOrNull (builder, bindStrNameDeviceTypes),
44944528 getArrayAttrOrNull (builder, workerDeviceTypes),
44954529 getArrayAttrOrNull (builder, vectorDeviceTypes),
44964530 getArrayAttrOrNull (builder, seqDeviceTypes), hasNohost,
@@ -4507,8 +4541,10 @@ static void interpretRoutineDeviceInfo(
45074541 llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
45084542 llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
45094543 llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
4510- llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4511- llvm::SmallVector<mlir::Attribute> &bindNames,
4544+ llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4545+ llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
4546+ llvm::SmallVector<mlir::Attribute> &bindIdNames,
4547+ llvm::SmallVector<mlir::Attribute> &bindStrNames,
45124548 llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
45134549 llvm::SmallVector<mlir::Attribute> &gangDimValues,
45144550 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
@@ -4541,16 +4577,18 @@ static void interpretRoutineDeviceInfo(
45414577 if (dinfo.bindNameOpt ().has_value ()) {
45424578 const auto &bindName = dinfo.bindNameOpt ().value ();
45434579 mlir::Attribute bindNameAttr;
4544- if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
4580+ if (const auto &bindSym{
4581+ std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4582+ bindNameAttr = builder.getSymbolRefAttr (converter.mangleName (*bindSym));
4583+ bindIdNames.push_back (bindNameAttr);
4584+ bindIdNameDeviceTypes.push_back (getDeviceTypeAttr ());
4585+ } else if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
45454586 bindNameAttr = builder.getStringAttr (*bindStr);
4546- } else if (const auto &bindSym{
4547- std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4548- bindNameAttr = builder.getStringAttr (converter.mangleName (*bindSym));
4587+ bindStrNames.push_back (bindNameAttr);
4588+ bindStrNameDeviceTypes.push_back (getDeviceTypeAttr ());
45494589 } else {
45504590 llvm_unreachable (" Unsupported bind name type" );
45514591 }
4552- bindNames.push_back (bindNameAttr);
4553- bindNameDeviceTypes.push_back (getDeviceTypeAttr ());
45544592 }
45554593}
45564594
@@ -4566,8 +4604,9 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45664604 bool hasNohost{false };
45674605
45684606 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
4569- workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4570- gangDimDeviceTypes, gangDimValues;
4607+ workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4608+ bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes,
4609+ gangDimValues;
45714610
45724611 for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
45734612 // Device Independent Attributes
@@ -4576,24 +4615,26 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45764615 }
45774616 // Note: Device Independent Attributes are set to the
45784617 // none device type in `info`.
4579- interpretRoutineDeviceInfo (converter, info, seqDeviceTypes,
4580- vectorDeviceTypes, workerDeviceTypes,
4581- bindNameDeviceTypes, bindNames, gangDeviceTypes ,
4582- gangDimValues, gangDimDeviceTypes);
4618+ interpretRoutineDeviceInfo (
4619+ converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
4620+ bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames ,
4621+ bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes);
45834622
45844623 // Device Dependent Attributes
45854624 for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
45864625 info.deviceTypeInfos ()) {
4587- interpretRoutineDeviceInfo (
4588- converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
4589- workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4590- gangDimValues, gangDimDeviceTypes);
4626+ interpretRoutineDeviceInfo (converter, dinfo, seqDeviceTypes,
4627+ vectorDeviceTypes, workerDeviceTypes,
4628+ bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4629+ bindIdNames, bindStrNames, gangDeviceTypes,
4630+ gangDimValues, gangDimDeviceTypes);
45914631 }
45924632 }
45934633 createOpenACCRoutineConstruct (
4594- converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
4595- bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
4596- seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
4634+ converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames,
4635+ bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4636+ gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes,
4637+ workerDeviceTypes, vectorDeviceTypes);
45974638}
45984639
45994640static void
0 commit comments