Skip to content

Commit f61dd65

Browse files
author
Delaram Talaashrafi
committed
[openacc][flang] Support SymbolRef and String bindName representations and device attribute in acc.routine
Based on the OpenACC specification — which states that if the bind name is given as an identifier it should be resolved according to the compiled language, and if given as a string it should be used unmodified — we introduce two distinct `bindName` representations for `acc routine` to handle each case appropriately: one as an array of `SymbolRefAttr` for identifiers and another as an array of `StringAttr` for strings. To ensure correct correspondence between bind names and devices, this patch also introduces two separate sets of device attributes. The routine operation is extended accordingly, along with the necessary updates to the OpenACC dialect and its lowering.
1 parent 82d7405 commit f61dd65

File tree

7 files changed

+210
-72
lines changed

7 files changed

+210
-72
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 72 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4396,10 +4396,35 @@ getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
43964396
return std::nullopt;
43974397
}
43984398

4399+
// Helper function to extract string value from bind name variant
4400+
static std::optional<llvm::StringRef> getBindNameStringValue(
4401+
const std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4402+
&bindNameValue) {
4403+
if (!bindNameValue.has_value()) {
4404+
return std::nullopt;
4405+
}
4406+
4407+
return std::visit(
4408+
[](const auto &attr) -> std::optional<llvm::StringRef> {
4409+
if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
4410+
mlir::StringAttr>) {
4411+
return attr.getValue();
4412+
} else if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
4413+
mlir::SymbolRefAttr>) {
4414+
return attr.getLeafReference();
4415+
} else {
4416+
return std::nullopt;
4417+
}
4418+
},
4419+
bindNameValue.value());
4420+
}
4421+
43994422
static bool compareDeviceTypeInfo(
44004423
mlir::acc::RoutineOp op,
4401-
llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
4402-
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
4424+
llvm::SmallVector<mlir::Attribute> &bindIdNameArrayAttr,
4425+
llvm::SmallVector<mlir::Attribute> &bindStrNameArrayAttr,
4426+
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypeArrayAttr,
4427+
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypeArrayAttr,
44034428
llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
44044429
llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
44054430
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
@@ -4409,9 +4434,13 @@ static bool compareDeviceTypeInfo(
44094434
for (uint32_t dtypeInt = 0;
44104435
dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
44114436
auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
4412-
if (op.getBindNameValue(dtype) !=
4413-
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4414-
bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
4437+
auto bindNameValue = getBindNameStringValue(op.getBindNameValue(dtype));
4438+
if (bindNameValue !=
4439+
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4440+
bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) &&
4441+
bindNameValue !=
4442+
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4443+
bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype))
44154444
return false;
44164445
if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
44174446
return false;
@@ -4458,8 +4487,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder,
44584487
void createOpenACCRoutineConstruct(
44594488
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
44604489
mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
4461-
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
4462-
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4490+
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindIdNames,
4491+
llvm::SmallVector<mlir::Attribute> &bindStrNames,
4492+
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4493+
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
44634494
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
44644495
llvm::SmallVector<mlir::Attribute> &gangDimValues,
44654496
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
@@ -4472,7 +4503,8 @@ void createOpenACCRoutineConstruct(
44724503
0) {
44734504
// If the routine is already specified with the same clauses, just skip
44744505
// the operation creation.
4475-
if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
4506+
if (compareDeviceTypeInfo(routineOp, bindIdNames, bindStrNames,
4507+
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
44764508
gangDeviceTypes, gangDimValues,
44774509
gangDimDeviceTypes, seqDeviceTypes,
44784510
workerDeviceTypes, vectorDeviceTypes) &&
@@ -4489,8 +4521,10 @@ void createOpenACCRoutineConstruct(
44894521
modBuilder.create<mlir::acc::RoutineOp>(
44904522
loc, routineOpStr,
44914523
mlir::SymbolRefAttr::get(builder.getContext(), funcName),
4492-
getArrayAttrOrNull(builder, bindNames),
4493-
getArrayAttrOrNull(builder, bindNameDeviceTypes),
4524+
getArrayAttrOrNull(builder, bindIdNames),
4525+
getArrayAttrOrNull(builder, bindStrNames),
4526+
getArrayAttrOrNull(builder, bindIdNameDeviceTypes),
4527+
getArrayAttrOrNull(builder, bindStrNameDeviceTypes),
44944528
getArrayAttrOrNull(builder, workerDeviceTypes),
44954529
getArrayAttrOrNull(builder, vectorDeviceTypes),
44964530
getArrayAttrOrNull(builder, seqDeviceTypes), hasNohost,
@@ -4507,8 +4541,10 @@ static void interpretRoutineDeviceInfo(
45074541
llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
45084542
llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
45094543
llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
4510-
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4511-
llvm::SmallVector<mlir::Attribute> &bindNames,
4544+
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4545+
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
4546+
llvm::SmallVector<mlir::Attribute> &bindIdNames,
4547+
llvm::SmallVector<mlir::Attribute> &bindStrNames,
45124548
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
45134549
llvm::SmallVector<mlir::Attribute> &gangDimValues,
45144550
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
@@ -4541,16 +4577,18 @@ static void interpretRoutineDeviceInfo(
45414577
if (dinfo.bindNameOpt().has_value()) {
45424578
const auto &bindName = dinfo.bindNameOpt().value();
45434579
mlir::Attribute bindNameAttr;
4544-
if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
4580+
if (const auto &bindSym{
4581+
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4582+
bindNameAttr = builder.getSymbolRefAttr(converter.mangleName(*bindSym));
4583+
bindIdNames.push_back(bindNameAttr);
4584+
bindIdNameDeviceTypes.push_back(getDeviceTypeAttr());
4585+
} else if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
45454586
bindNameAttr = builder.getStringAttr(*bindStr);
4546-
} else if (const auto &bindSym{
4547-
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4548-
bindNameAttr = builder.getStringAttr(converter.mangleName(*bindSym));
4587+
bindStrNames.push_back(bindNameAttr);
4588+
bindStrNameDeviceTypes.push_back(getDeviceTypeAttr());
45494589
} else {
45504590
llvm_unreachable("Unsupported bind name type");
45514591
}
4552-
bindNames.push_back(bindNameAttr);
4553-
bindNameDeviceTypes.push_back(getDeviceTypeAttr());
45544592
}
45554593
}
45564594

@@ -4566,8 +4604,9 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45664604
bool hasNohost{false};
45674605

45684606
llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
4569-
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4570-
gangDimDeviceTypes, gangDimValues;
4607+
workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4608+
bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes,
4609+
gangDimValues;
45714610

45724611
for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
45734612
// Device Independent Attributes
@@ -4576,24 +4615,26 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45764615
}
45774616
// Note: Device Independent Attributes are set to the
45784617
// none device type in `info`.
4579-
interpretRoutineDeviceInfo(converter, info, seqDeviceTypes,
4580-
vectorDeviceTypes, workerDeviceTypes,
4581-
bindNameDeviceTypes, bindNames, gangDeviceTypes,
4582-
gangDimValues, gangDimDeviceTypes);
4618+
interpretRoutineDeviceInfo(
4619+
converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
4620+
bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames,
4621+
bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes);
45834622

45844623
// Device Dependent Attributes
45854624
for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
45864625
info.deviceTypeInfos()) {
4587-
interpretRoutineDeviceInfo(
4588-
converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
4589-
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4590-
gangDimValues, gangDimDeviceTypes);
4626+
interpretRoutineDeviceInfo(converter, dinfo, seqDeviceTypes,
4627+
vectorDeviceTypes, workerDeviceTypes,
4628+
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4629+
bindIdNames, bindStrNames, gangDeviceTypes,
4630+
gangDimValues, gangDimDeviceTypes);
45914631
}
45924632
}
45934633
createOpenACCRoutineConstruct(
4594-
converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
4595-
bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
4596-
seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
4634+
converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames,
4635+
bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4636+
gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes,
4637+
workerDeviceTypes, vectorDeviceTypes);
45974638
}
45984639

45994640
static void

flang/test/Lower/OpenACC/acc-routine.f90

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
44

5-
! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine17" [#acc.device_type<default>], "_QPacc_routine16" [#acc.device_type<multicore>])
6-
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
5+
! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine17
6+
! [#acc.device_type<default>], @_QPacc_routine16 [#acc.device_type<multicore>])
7+
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine16 [#acc.device_type<multicore>])
78
! CHECK: acc.routine @[[r12:.*]] func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
89
! CHECK: acc.routine @[[r11:.*]] func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
910
! CHECK: acc.routine @[[r10:.*]] func(@_QPacc_routine11) seq
1011
! CHECK: acc.routine @[[r09:.*]] func(@_QPacc_routine10) seq
11-
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind("_QPacc_routine9a")
12+
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind(@_QPacc_routine9a)
1213
! CHECK: acc.routine @[[r07:.*]] func(@_QPacc_routine8) bind("routine8_")
1314
! CHECK: acc.routine @[[r06:.*]] func(@_QPacc_routine7) gang(dim: 1 : i64)
1415
! CHECK: acc.routine @[[r05:.*]] func(@_QPacc_routine6) nohost

flang/test/Lower/OpenACC/acc-routine03.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ subroutine sub2(a)
3030
end subroutine
3131

3232
! CHECK: acc.routine @acc_routine_1 func(@_QPsub2) worker nohost
33-
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind("_QPsub2") worker
33+
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind(@_QPsub2) worker
3434
! CHECK: func.func @_QPsub1(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>}
3535
! CHECK: func.func @_QPsub2(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_1]>}

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mlir/Interfaces/ControlFlowInterfaces.h"
3030
#include "mlir/Interfaces/LoopLikeInterface.h"
3131
#include "mlir/Interfaces/SideEffectInterfaces.h"
32+
#include <variant>
3233

3334
#define GET_TYPEDEF_CLASSES
3435
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.h.inc"

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,8 +2772,10 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
27722772
}];
27732773

27742774
let arguments = (ins SymbolNameAttr:$sym_name, SymbolRefAttr:$func_name,
2775-
OptionalAttr<StrArrayAttr>:$bindName,
2776-
OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
2775+
OptionalAttr<SymbolRefArrayAttr>:$bindIdName,
2776+
OptionalAttr<StrArrayAttr>:$bindStrName,
2777+
OptionalAttr<DeviceTypeArrayAttr>:$bindIdNameDeviceType,
2778+
OptionalAttr<DeviceTypeArrayAttr>:$bindStrNameDeviceType,
27772779
OptionalAttr<DeviceTypeArrayAttr>:$worker,
27782780
OptionalAttr<DeviceTypeArrayAttr>:$vector,
27792781
OptionalAttr<DeviceTypeArrayAttr>:$seq, UnitAttr:$nohost,
@@ -2815,14 +2817,14 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
28152817
std::optional<int64_t> getGangDimValue();
28162818
std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
28172819

2818-
std::optional<llvm::StringRef> getBindNameValue();
2819-
std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
2820+
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue();
2821+
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue(mlir::acc::DeviceType deviceType);
28202822
}];
28212823

28222824
let assemblyFormat = [{
28232825
$sym_name `func` `(` $func_name `)`
28242826
oilist (
2825-
`bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
2827+
`bind` `(` custom<BindName>($bindIdName, $bindStrName ,$bindIdNameDeviceType, $bindStrNameDeviceType) `)`
28262828
| `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
28272829
| `worker` custom<DeviceTypeArrayAttr>($worker)
28282830
| `vector` custom<DeviceTypeArrayAttr>($vector)

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/SmallSet.h"
2222
#include "llvm/ADT/TypeSwitch.h"
2323
#include "llvm/Support/LogicalResult.h"
24+
#include <variant>
2425

2526
using namespace mlir;
2627
using namespace acc;
@@ -3461,40 +3462,89 @@ LogicalResult acc::RoutineOp::verify() {
34613462
return success();
34623463
}
34633464

3464-
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
3465-
mlir::ArrayAttr &deviceTypes) {
3466-
llvm::SmallVector<mlir::Attribute> bindNameAttrs;
3467-
llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
3465+
static ParseResult parseBindName(OpAsmParser &parser,
3466+
mlir::ArrayAttr &bindIdName,
3467+
mlir::ArrayAttr &bindStrName,
3468+
mlir::ArrayAttr &deviceIdTypes,
3469+
mlir::ArrayAttr &deviceStrTypes) {
3470+
llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
3471+
llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
3472+
llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
3473+
llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
34683474

34693475
if (failed(parser.parseCommaSeparatedList([&]() {
3470-
if (parser.parseAttribute(bindNameAttrs.emplace_back()))
3476+
mlir::Attribute newAttr;
3477+
bool isSymbolRefAttr;
3478+
auto parseResult = parser.parseAttribute(newAttr);
3479+
if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
3480+
bindIdNameAttrs.push_back(symbolRefAttr);
3481+
isSymbolRefAttr = true;
3482+
} else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
3483+
bindStrNameAttrs.push_back(stringAttr);
3484+
isSymbolRefAttr = false;
3485+
}
3486+
if (parseResult)
34713487
return failure();
34723488
if (failed(parser.parseOptionalLSquare())) {
3473-
deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3474-
parser.getContext(), mlir::acc::DeviceType::None));
3489+
if (isSymbolRefAttr) {
3490+
deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3491+
parser.getContext(), mlir::acc::DeviceType::None));
3492+
} else {
3493+
deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3494+
parser.getContext(), mlir::acc::DeviceType::None));
3495+
}
34753496
} else {
3476-
if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
3477-
parser.parseRSquare())
3478-
return failure();
3497+
if (isSymbolRefAttr) {
3498+
if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
3499+
parser.parseRSquare())
3500+
return failure();
3501+
} else {
3502+
if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
3503+
parser.parseRSquare())
3504+
return failure();
3505+
}
34793506
}
34803507
return success();
34813508
})))
34823509
return failure();
34833510

3484-
bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
3485-
deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
3511+
bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
3512+
bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
3513+
deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
3514+
deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
34863515

34873516
return success();
34883517
}
34893518

34903519
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
3491-
std::optional<mlir::ArrayAttr> bindName,
3492-
std::optional<mlir::ArrayAttr> deviceTypes) {
3493-
llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
3494-
[&](const auto &pair) {
3495-
p << std::get<0>(pair);
3496-
printSingleDeviceType(p, std::get<1>(pair));
3497-
});
3520+
std::optional<mlir::ArrayAttr> bindIdName,
3521+
std::optional<mlir::ArrayAttr> bindStrName,
3522+
std::optional<mlir::ArrayAttr> deviceIdTypes,
3523+
std::optional<mlir::ArrayAttr> deviceStrTypes) {
3524+
// Create combined vectors for all bind names and device types
3525+
llvm::SmallVector<mlir::Attribute> allBindNames;
3526+
llvm::SmallVector<mlir::Attribute> allDeviceTypes;
3527+
3528+
// Append bindIdName and deviceIdTypes
3529+
if (hasDeviceTypeValues(deviceIdTypes)) {
3530+
allBindNames.append(bindIdName->begin(), bindIdName->end());
3531+
allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
3532+
}
3533+
3534+
// Append bindStrName and deviceStrTypes
3535+
if (hasDeviceTypeValues(deviceStrTypes)) {
3536+
allBindNames.append(bindStrName->begin(), bindStrName->end());
3537+
allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
3538+
}
3539+
3540+
// Print the combined sequence
3541+
if (!allBindNames.empty()) {
3542+
llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
3543+
[&](const auto &pair) {
3544+
p << std::get<0>(pair);
3545+
printSingleDeviceType(p, std::get<1>(pair));
3546+
});
3547+
}
34983548
}
34993549

35003550
static ParseResult parseRoutineGangClause(OpAsmParser &parser,
@@ -3654,19 +3704,32 @@ bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
36543704
return hasDeviceType(getSeq(), deviceType);
36553705
}
36563706

3657-
std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
3707+
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
3708+
RoutineOp::getBindNameValue() {
36583709
return getBindNameValue(mlir::acc::DeviceType::None);
36593710
}
36603711

3661-
std::optional<llvm::StringRef>
3712+
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
36623713
RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3663-
if (!hasDeviceTypeValues(getBindNameDeviceType()))
3714+
if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
3715+
!hasDeviceTypeValues(getBindStrNameDeviceType())) {
36643716
return std::nullopt;
3665-
if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
3666-
auto attr = (*getBindName())[*pos];
3717+
}
3718+
3719+
if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
3720+
auto attr = (*getBindIdName())[*pos];
3721+
auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
3722+
assert(symbolRefAttr && "expected SymbolRef");
3723+
return symbolRefAttr;
3724+
}
3725+
3726+
if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
3727+
auto attr = (*getBindStrName())[*pos];
36673728
auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3668-
return stringAttr.getValue();
3729+
assert(stringAttr && "expected String");
3730+
return stringAttr;
36693731
}
3732+
36703733
return std::nullopt;
36713734
}
36723735

0 commit comments

Comments
 (0)