Skip to content

Commit 0dae924

Browse files
[openacc][flang] Support two type bindName representation in acc routine (llvm#149147)
Based on the OpenACC specification — which states that if the bind name is given as an identifier it should be resolved according to the compiled language, and if given as a string it should be used unmodified — we introduce two distinct `bindName` representations for `acc routine` to handle each case appropriately: one as an array of `SymbolRefAttr` for identifiers and another as an array of `StringAttr` for strings. To ensure correct correspondence between bind names and devices, this patch also introduces two separate sets of device attributes. The routine operation is extended accordingly, along with the necessary updates to the OpenACC dialect and its lowering.
1 parent 661cbd5 commit 0dae924

File tree

7 files changed

+208
-72
lines changed

7 files changed

+208
-72
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4414,10 +4414,34 @@ getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
44144414
return std::nullopt;
44154415
}
44164416

4417+
// Helper function to extract string value from bind name variant
4418+
static std::optional<llvm::StringRef> getBindNameStringValue(
4419+
const std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4420+
&bindNameValue) {
4421+
if (!bindNameValue.has_value())
4422+
return std::nullopt;
4423+
4424+
return std::visit(
4425+
[](const auto &attr) -> std::optional<llvm::StringRef> {
4426+
if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
4427+
mlir::StringAttr>) {
4428+
return attr.getValue();
4429+
} else if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
4430+
mlir::SymbolRefAttr>) {
4431+
return attr.getLeafReference();
4432+
} else {
4433+
return std::nullopt;
4434+
}
4435+
},
4436+
bindNameValue.value());
4437+
}
4438+
44174439
static bool compareDeviceTypeInfo(
44184440
mlir::acc::RoutineOp op,
4419-
llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
4420-
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
4441+
llvm::SmallVector<mlir::Attribute> &bindIdNameArrayAttr,
4442+
llvm::SmallVector<mlir::Attribute> &bindStrNameArrayAttr,
4443+
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypeArrayAttr,
4444+
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypeArrayAttr,
44214445
llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
44224446
llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
44234447
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
@@ -4427,9 +4451,13 @@ static bool compareDeviceTypeInfo(
44274451
for (uint32_t dtypeInt = 0;
44284452
dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
44294453
auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
4430-
if (op.getBindNameValue(dtype) !=
4431-
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4432-
bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
4454+
auto bindNameValue = getBindNameStringValue(op.getBindNameValue(dtype));
4455+
if (bindNameValue !=
4456+
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4457+
bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) &&
4458+
bindNameValue !=
4459+
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4460+
bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype))
44334461
return false;
44344462
if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
44354463
return false;
@@ -4476,8 +4504,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder,
44764504
void createOpenACCRoutineConstruct(
44774505
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
44784506
mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
4479-
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
4480-
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4507+
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindIdNames,
4508+
llvm::SmallVector<mlir::Attribute> &bindStrNames,
4509+
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4510+
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
44814511
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
44824512
llvm::SmallVector<mlir::Attribute> &gangDimValues,
44834513
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
@@ -4490,7 +4520,8 @@ void createOpenACCRoutineConstruct(
44904520
0) {
44914521
// If the routine is already specified with the same clauses, just skip
44924522
// the operation creation.
4493-
if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
4523+
if (compareDeviceTypeInfo(routineOp, bindIdNames, bindStrNames,
4524+
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
44944525
gangDeviceTypes, gangDimValues,
44954526
gangDimDeviceTypes, seqDeviceTypes,
44964527
workerDeviceTypes, vectorDeviceTypes) &&
@@ -4507,8 +4538,10 @@ void createOpenACCRoutineConstruct(
45074538
modBuilder.create<mlir::acc::RoutineOp>(
45084539
loc, routineOpStr,
45094540
mlir::SymbolRefAttr::get(builder.getContext(), funcName),
4510-
getArrayAttrOrNull(builder, bindNames),
4511-
getArrayAttrOrNull(builder, bindNameDeviceTypes),
4541+
getArrayAttrOrNull(builder, bindIdNames),
4542+
getArrayAttrOrNull(builder, bindStrNames),
4543+
getArrayAttrOrNull(builder, bindIdNameDeviceTypes),
4544+
getArrayAttrOrNull(builder, bindStrNameDeviceTypes),
45124545
getArrayAttrOrNull(builder, workerDeviceTypes),
45134546
getArrayAttrOrNull(builder, vectorDeviceTypes),
45144547
getArrayAttrOrNull(builder, seqDeviceTypes), hasNohost,
@@ -4525,8 +4558,10 @@ static void interpretRoutineDeviceInfo(
45254558
llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
45264559
llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
45274560
llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
4528-
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4529-
llvm::SmallVector<mlir::Attribute> &bindNames,
4561+
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4562+
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
4563+
llvm::SmallVector<mlir::Attribute> &bindIdNames,
4564+
llvm::SmallVector<mlir::Attribute> &bindStrNames,
45304565
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
45314566
llvm::SmallVector<mlir::Attribute> &gangDimValues,
45324567
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
@@ -4559,16 +4594,18 @@ static void interpretRoutineDeviceInfo(
45594594
if (dinfo.bindNameOpt().has_value()) {
45604595
const auto &bindName = dinfo.bindNameOpt().value();
45614596
mlir::Attribute bindNameAttr;
4562-
if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
4597+
if (const auto &bindSym{
4598+
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4599+
bindNameAttr = builder.getSymbolRefAttr(converter.mangleName(*bindSym));
4600+
bindIdNames.push_back(bindNameAttr);
4601+
bindIdNameDeviceTypes.push_back(getDeviceTypeAttr());
4602+
} else if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
45634603
bindNameAttr = builder.getStringAttr(*bindStr);
4564-
} else if (const auto &bindSym{
4565-
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4566-
bindNameAttr = builder.getStringAttr(converter.mangleName(*bindSym));
4604+
bindStrNames.push_back(bindNameAttr);
4605+
bindStrNameDeviceTypes.push_back(getDeviceTypeAttr());
45674606
} else {
45684607
llvm_unreachable("Unsupported bind name type");
45694608
}
4570-
bindNames.push_back(bindNameAttr);
4571-
bindNameDeviceTypes.push_back(getDeviceTypeAttr());
45724609
}
45734610
}
45744611

@@ -4584,8 +4621,9 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45844621
bool hasNohost{false};
45854622

45864623
llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
4587-
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4588-
gangDimDeviceTypes, gangDimValues;
4624+
workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4625+
bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes,
4626+
gangDimValues;
45894627

45904628
for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
45914629
// Device Independent Attributes
@@ -4594,24 +4632,26 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45944632
}
45954633
// Note: Device Independent Attributes are set to the
45964634
// none device type in `info`.
4597-
interpretRoutineDeviceInfo(converter, info, seqDeviceTypes,
4598-
vectorDeviceTypes, workerDeviceTypes,
4599-
bindNameDeviceTypes, bindNames, gangDeviceTypes,
4600-
gangDimValues, gangDimDeviceTypes);
4635+
interpretRoutineDeviceInfo(
4636+
converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
4637+
bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames,
4638+
bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes);
46014639

46024640
// Device Dependent Attributes
46034641
for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
46044642
info.deviceTypeInfos()) {
4605-
interpretRoutineDeviceInfo(
4606-
converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
4607-
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4608-
gangDimValues, gangDimDeviceTypes);
4643+
interpretRoutineDeviceInfo(converter, dinfo, seqDeviceTypes,
4644+
vectorDeviceTypes, workerDeviceTypes,
4645+
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4646+
bindIdNames, bindStrNames, gangDeviceTypes,
4647+
gangDimValues, gangDimDeviceTypes);
46094648
}
46104649
}
46114650
createOpenACCRoutineConstruct(
4612-
converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
4613-
bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
4614-
seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
4651+
converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames,
4652+
bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4653+
gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes,
4654+
workerDeviceTypes, vectorDeviceTypes);
46154655
}
46164656

46174657
static void

flang/test/Lower/OpenACC/acc-routine.f90

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
44

5-
! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine17" [#acc.device_type<default>], "_QPacc_routine16" [#acc.device_type<multicore>])
6-
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
5+
! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine17
6+
! [#acc.device_type<default>], @_QPacc_routine16 [#acc.device_type<multicore>])
7+
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine16 [#acc.device_type<multicore>])
78
! CHECK: acc.routine @[[r12:.*]] func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
89
! CHECK: acc.routine @[[r11:.*]] func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
910
! CHECK: acc.routine @[[r10:.*]] func(@_QPacc_routine11) seq
1011
! CHECK: acc.routine @[[r09:.*]] func(@_QPacc_routine10) seq
11-
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind("_QPacc_routine9a")
12+
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind(@_QPacc_routine9a)
1213
! CHECK: acc.routine @[[r07:.*]] func(@_QPacc_routine8) bind("routine8_")
1314
! CHECK: acc.routine @[[r06:.*]] func(@_QPacc_routine7) gang(dim: 1 : i64)
1415
! CHECK: acc.routine @[[r05:.*]] func(@_QPacc_routine6) nohost

flang/test/Lower/OpenACC/acc-routine03.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ subroutine sub2(a)
3030
end subroutine
3131

3232
! CHECK: acc.routine @acc_routine_1 func(@_QPsub2) worker nohost
33-
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind("_QPsub2") worker
33+
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind(@_QPsub2) worker
3434
! CHECK: func.func @_QPsub1(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>}
3535
! CHECK: func.func @_QPsub2(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_1]>}

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mlir/Interfaces/ControlFlowInterfaces.h"
3030
#include "mlir/Interfaces/LoopLikeInterface.h"
3131
#include "mlir/Interfaces/SideEffectInterfaces.h"
32+
#include <variant>
3233

3334
#define GET_TYPEDEF_CLASSES
3435
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.h.inc"

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,8 +2772,10 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
27722772
}];
27732773

27742774
let arguments = (ins SymbolNameAttr:$sym_name, SymbolRefAttr:$func_name,
2775-
OptionalAttr<StrArrayAttr>:$bindName,
2776-
OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
2775+
OptionalAttr<SymbolRefArrayAttr>:$bindIdName,
2776+
OptionalAttr<StrArrayAttr>:$bindStrName,
2777+
OptionalAttr<DeviceTypeArrayAttr>:$bindIdNameDeviceType,
2778+
OptionalAttr<DeviceTypeArrayAttr>:$bindStrNameDeviceType,
27772779
OptionalAttr<DeviceTypeArrayAttr>:$worker,
27782780
OptionalAttr<DeviceTypeArrayAttr>:$vector,
27792781
OptionalAttr<DeviceTypeArrayAttr>:$seq, UnitAttr:$nohost,
@@ -2815,14 +2817,14 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
28152817
std::optional<int64_t> getGangDimValue();
28162818
std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
28172819

2818-
std::optional<llvm::StringRef> getBindNameValue();
2819-
std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
2820+
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue();
2821+
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue(mlir::acc::DeviceType deviceType);
28202822
}];
28212823

28222824
let assemblyFormat = [{
28232825
$sym_name `func` `(` $func_name `)`
28242826
oilist (
2825-
`bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
2827+
`bind` `(` custom<BindName>($bindIdName, $bindStrName ,$bindIdNameDeviceType, $bindStrNameDeviceType) `)`
28262828
| `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
28272829
| `worker` custom<DeviceTypeArrayAttr>($worker)
28282830
| `vector` custom<DeviceTypeArrayAttr>($vector)

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 87 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/SmallSet.h"
2222
#include "llvm/ADT/TypeSwitch.h"
2323
#include "llvm/Support/LogicalResult.h"
24+
#include <variant>
2425

2526
using namespace mlir;
2627
using namespace acc;
@@ -3461,40 +3462,88 @@ LogicalResult acc::RoutineOp::verify() {
34613462
return success();
34623463
}
34633464

3464-
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
3465-
mlir::ArrayAttr &deviceTypes) {
3466-
llvm::SmallVector<mlir::Attribute> bindNameAttrs;
3467-
llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
3465+
static ParseResult parseBindName(OpAsmParser &parser,
3466+
mlir::ArrayAttr &bindIdName,
3467+
mlir::ArrayAttr &bindStrName,
3468+
mlir::ArrayAttr &deviceIdTypes,
3469+
mlir::ArrayAttr &deviceStrTypes) {
3470+
llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
3471+
llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
3472+
llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
3473+
llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
34683474

34693475
if (failed(parser.parseCommaSeparatedList([&]() {
3470-
if (parser.parseAttribute(bindNameAttrs.emplace_back()))
3476+
mlir::Attribute newAttr;
3477+
bool isSymbolRefAttr;
3478+
auto parseResult = parser.parseAttribute(newAttr);
3479+
if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
3480+
bindIdNameAttrs.push_back(symbolRefAttr);
3481+
isSymbolRefAttr = true;
3482+
} else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
3483+
bindStrNameAttrs.push_back(stringAttr);
3484+
isSymbolRefAttr = false;
3485+
}
3486+
if (parseResult)
34713487
return failure();
34723488
if (failed(parser.parseOptionalLSquare())) {
3473-
deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3474-
parser.getContext(), mlir::acc::DeviceType::None));
3489+
if (isSymbolRefAttr) {
3490+
deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3491+
parser.getContext(), mlir::acc::DeviceType::None));
3492+
} else {
3493+
deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3494+
parser.getContext(), mlir::acc::DeviceType::None));
3495+
}
34753496
} else {
3476-
if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
3477-
parser.parseRSquare())
3478-
return failure();
3497+
if (isSymbolRefAttr) {
3498+
if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
3499+
parser.parseRSquare())
3500+
return failure();
3501+
} else {
3502+
if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
3503+
parser.parseRSquare())
3504+
return failure();
3505+
}
34793506
}
34803507
return success();
34813508
})))
34823509
return failure();
34833510

3484-
bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
3485-
deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
3511+
bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
3512+
bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
3513+
deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
3514+
deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
34863515

34873516
return success();
34883517
}
34893518

34903519
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
3491-
std::optional<mlir::ArrayAttr> bindName,
3492-
std::optional<mlir::ArrayAttr> deviceTypes) {
3493-
llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
3494-
[&](const auto &pair) {
3495-
p << std::get<0>(pair);
3496-
printSingleDeviceType(p, std::get<1>(pair));
3497-
});
3520+
std::optional<mlir::ArrayAttr> bindIdName,
3521+
std::optional<mlir::ArrayAttr> bindStrName,
3522+
std::optional<mlir::ArrayAttr> deviceIdTypes,
3523+
std::optional<mlir::ArrayAttr> deviceStrTypes) {
3524+
// Create combined vectors for all bind names and device types
3525+
llvm::SmallVector<mlir::Attribute> allBindNames;
3526+
llvm::SmallVector<mlir::Attribute> allDeviceTypes;
3527+
3528+
// Append bindIdName and deviceIdTypes
3529+
if (hasDeviceTypeValues(deviceIdTypes)) {
3530+
allBindNames.append(bindIdName->begin(), bindIdName->end());
3531+
allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
3532+
}
3533+
3534+
// Append bindStrName and deviceStrTypes
3535+
if (hasDeviceTypeValues(deviceStrTypes)) {
3536+
allBindNames.append(bindStrName->begin(), bindStrName->end());
3537+
allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
3538+
}
3539+
3540+
// Print the combined sequence
3541+
if (!allBindNames.empty())
3542+
llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
3543+
[&](const auto &pair) {
3544+
p << std::get<0>(pair);
3545+
printSingleDeviceType(p, std::get<1>(pair));
3546+
});
34983547
}
34993548

35003549
static ParseResult parseRoutineGangClause(OpAsmParser &parser,
@@ -3654,19 +3703,32 @@ bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
36543703
return hasDeviceType(getSeq(), deviceType);
36553704
}
36563705

3657-
std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
3706+
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
3707+
RoutineOp::getBindNameValue() {
36583708
return getBindNameValue(mlir::acc::DeviceType::None);
36593709
}
36603710

3661-
std::optional<llvm::StringRef>
3711+
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
36623712
RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3663-
if (!hasDeviceTypeValues(getBindNameDeviceType()))
3713+
if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
3714+
!hasDeviceTypeValues(getBindStrNameDeviceType())) {
36643715
return std::nullopt;
3665-
if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
3666-
auto attr = (*getBindName())[*pos];
3716+
}
3717+
3718+
if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
3719+
auto attr = (*getBindIdName())[*pos];
3720+
auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
3721+
assert(symbolRefAttr && "expected SymbolRef");
3722+
return symbolRefAttr;
3723+
}
3724+
3725+
if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
3726+
auto attr = (*getBindStrName())[*pos];
36673727
auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3668-
return stringAttr.getValue();
3728+
assert(stringAttr && "expected String");
3729+
return stringAttr;
36693730
}
3731+
36703732
return std::nullopt;
36713733
}
36723734

0 commit comments

Comments
 (0)