@@ -4414,10 +4414,34 @@ getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
44144414 return std::nullopt ;
44154415}
44164416
4417+ // Helper function to extract string value from bind name variant
4418+ static std::optional<llvm::StringRef> getBindNameStringValue (
4419+ const std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4420+ &bindNameValue) {
4421+ if (!bindNameValue.has_value ())
4422+ return std::nullopt ;
4423+
4424+ return std::visit (
4425+ [](const auto &attr) -> std::optional<llvm::StringRef> {
4426+ if constexpr (std::is_same_v<std::decay_t <decltype (attr)>,
4427+ mlir::StringAttr>) {
4428+ return attr.getValue ();
4429+ } else if constexpr (std::is_same_v<std::decay_t <decltype (attr)>,
4430+ mlir::SymbolRefAttr>) {
4431+ return attr.getLeafReference ();
4432+ } else {
4433+ return std::nullopt ;
4434+ }
4435+ },
4436+ bindNameValue.value ());
4437+ }
4438+
44174439static bool compareDeviceTypeInfo (
44184440 mlir::acc::RoutineOp op,
4419- llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
4420- llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
4441+ llvm::SmallVector<mlir::Attribute> &bindIdNameArrayAttr,
4442+ llvm::SmallVector<mlir::Attribute> &bindStrNameArrayAttr,
4443+ llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypeArrayAttr,
4444+ llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypeArrayAttr,
44214445 llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
44224446 llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
44234447 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
@@ -4427,9 +4451,13 @@ static bool compareDeviceTypeInfo(
44274451 for (uint32_t dtypeInt = 0 ;
44284452 dtypeInt != mlir::acc::getMaxEnumValForDeviceType (); ++dtypeInt) {
44294453 auto dtype = static_cast <mlir::acc::DeviceType>(dtypeInt);
4430- if (op.getBindNameValue (dtype) !=
4431- getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4432- bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
4454+ auto bindNameValue = getBindNameStringValue (op.getBindNameValue (dtype));
4455+ if (bindNameValue !=
4456+ getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4457+ bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) &&
4458+ bindNameValue !=
4459+ getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4460+ bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype))
44334461 return false ;
44344462 if (op.hasGang (dtype) != hasDeviceType (gangArrayAttr, dtype))
44354463 return false ;
@@ -4476,8 +4504,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder,
44764504void createOpenACCRoutineConstruct (
44774505 Fortran::lower::AbstractConverter &converter, mlir::Location loc,
44784506 mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
4479- bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
4480- llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4507+ bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindIdNames,
4508+ llvm::SmallVector<mlir::Attribute> &bindStrNames,
4509+ llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4510+ llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
44814511 llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
44824512 llvm::SmallVector<mlir::Attribute> &gangDimValues,
44834513 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
@@ -4490,7 +4520,8 @@ void createOpenACCRoutineConstruct(
44904520 0 ) {
44914521 // If the routine is already specified with the same clauses, just skip
44924522 // the operation creation.
4493- if (compareDeviceTypeInfo (routineOp, bindNames, bindNameDeviceTypes,
4523+ if (compareDeviceTypeInfo (routineOp, bindIdNames, bindStrNames,
4524+ bindIdNameDeviceTypes, bindStrNameDeviceTypes,
44944525 gangDeviceTypes, gangDimValues,
44954526 gangDimDeviceTypes, seqDeviceTypes,
44964527 workerDeviceTypes, vectorDeviceTypes) &&
@@ -4507,8 +4538,10 @@ void createOpenACCRoutineConstruct(
45074538 modBuilder.create <mlir::acc::RoutineOp>(
45084539 loc, routineOpStr,
45094540 mlir::SymbolRefAttr::get (builder.getContext (), funcName),
4510- getArrayAttrOrNull (builder, bindNames),
4511- getArrayAttrOrNull (builder, bindNameDeviceTypes),
4541+ getArrayAttrOrNull (builder, bindIdNames),
4542+ getArrayAttrOrNull (builder, bindStrNames),
4543+ getArrayAttrOrNull (builder, bindIdNameDeviceTypes),
4544+ getArrayAttrOrNull (builder, bindStrNameDeviceTypes),
45124545 getArrayAttrOrNull (builder, workerDeviceTypes),
45134546 getArrayAttrOrNull (builder, vectorDeviceTypes),
45144547 getArrayAttrOrNull (builder, seqDeviceTypes), hasNohost,
@@ -4525,8 +4558,10 @@ static void interpretRoutineDeviceInfo(
45254558 llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
45264559 llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
45274560 llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
4528- llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4529- llvm::SmallVector<mlir::Attribute> &bindNames,
4561+ llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4562+ llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
4563+ llvm::SmallVector<mlir::Attribute> &bindIdNames,
4564+ llvm::SmallVector<mlir::Attribute> &bindStrNames,
45304565 llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
45314566 llvm::SmallVector<mlir::Attribute> &gangDimValues,
45324567 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
@@ -4559,16 +4594,18 @@ static void interpretRoutineDeviceInfo(
45594594 if (dinfo.bindNameOpt ().has_value ()) {
45604595 const auto &bindName = dinfo.bindNameOpt ().value ();
45614596 mlir::Attribute bindNameAttr;
4562- if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
4597+ if (const auto &bindSym{
4598+ std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4599+ bindNameAttr = builder.getSymbolRefAttr (converter.mangleName (*bindSym));
4600+ bindIdNames.push_back (bindNameAttr);
4601+ bindIdNameDeviceTypes.push_back (getDeviceTypeAttr ());
4602+ } else if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
45634603 bindNameAttr = builder.getStringAttr (*bindStr);
4564- } else if (const auto &bindSym{
4565- std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4566- bindNameAttr = builder.getStringAttr (converter.mangleName (*bindSym));
4604+ bindStrNames.push_back (bindNameAttr);
4605+ bindStrNameDeviceTypes.push_back (getDeviceTypeAttr ());
45674606 } else {
45684607 llvm_unreachable (" Unsupported bind name type" );
45694608 }
4570- bindNames.push_back (bindNameAttr);
4571- bindNameDeviceTypes.push_back (getDeviceTypeAttr ());
45724609 }
45734610}
45744611
@@ -4584,8 +4621,9 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45844621 bool hasNohost{false };
45854622
45864623 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
4587- workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4588- gangDimDeviceTypes, gangDimValues;
4624+ workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4625+ bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes,
4626+ gangDimValues;
45894627
45904628 for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
45914629 // Device Independent Attributes
@@ -4594,24 +4632,26 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45944632 }
45954633 // Note: Device Independent Attributes are set to the
45964634 // none device type in `info`.
4597- interpretRoutineDeviceInfo (converter, info, seqDeviceTypes,
4598- vectorDeviceTypes, workerDeviceTypes,
4599- bindNameDeviceTypes, bindNames, gangDeviceTypes ,
4600- gangDimValues, gangDimDeviceTypes);
4635+ interpretRoutineDeviceInfo (
4636+ converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
4637+ bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames ,
4638+ bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes);
46014639
46024640 // Device Dependent Attributes
46034641 for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
46044642 info.deviceTypeInfos ()) {
4605- interpretRoutineDeviceInfo (
4606- converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
4607- workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4608- gangDimValues, gangDimDeviceTypes);
4643+ interpretRoutineDeviceInfo (converter, dinfo, seqDeviceTypes,
4644+ vectorDeviceTypes, workerDeviceTypes,
4645+ bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4646+ bindIdNames, bindStrNames, gangDeviceTypes,
4647+ gangDimValues, gangDimDeviceTypes);
46094648 }
46104649 }
46114650 createOpenACCRoutineConstruct (
4612- converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
4613- bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
4614- seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
4651+ converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames,
4652+ bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4653+ gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes,
4654+ workerDeviceTypes, vectorDeviceTypes);
46154655}
46164656
46174657static void
0 commit comments