Skip to content

Commit c189df8

Browse files
authored
[flang][cuda] Fix resolution of overloaded operator (#122402)
1 parent dab6463 commit c189df8

File tree

2 files changed

+38
-20
lines changed

2 files changed

+38
-20
lines changed

flang/lib/Semantics/resolve-names.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8970,18 +8970,6 @@ void ResolveNamesVisitor::FinishSpecificationPart(
89708970
misparsedStmtFuncFound_ = false;
89718971
funcResultStack().CompleteFunctionResultType();
89728972
CheckImports();
8973-
bool inDeviceSubprogram = false;
8974-
if (auto *subp{currScope().symbol()
8975-
? currScope().symbol()->detailsIf<SubprogramDetails>()
8976-
: nullptr}) {
8977-
if (auto attrs{subp->cudaSubprogramAttrs()}) {
8978-
if (*attrs == common::CUDASubprogramAttrs::Device ||
8979-
*attrs == common::CUDASubprogramAttrs::Global ||
8980-
*attrs == common::CUDASubprogramAttrs::Grid_Global) {
8981-
inDeviceSubprogram = true;
8982-
}
8983-
}
8984-
}
89858973
for (auto &pair : currScope()) {
89868974
auto &symbol{*pair.second};
89878975
if (inInterfaceBlock()) {
@@ -8990,14 +8978,6 @@ void ResolveNamesVisitor::FinishSpecificationPart(
89908978
if (NeedsExplicitType(symbol)) {
89918979
ApplyImplicitRules(symbol);
89928980
}
8993-
if (inDeviceSubprogram && symbol.has<ObjectEntityDetails>()) {
8994-
auto *object{symbol.detailsIf<ObjectEntityDetails>()};
8995-
if (!object->cudaDataAttr() && !IsValue(symbol) &&
8996-
(IsDummy(symbol) || object->IsArray())) {
8997-
// Implicitly set device attribute if none is set in device context.
8998-
object->set_cudaDataAttr(common::CUDADataAttr::Device);
8999-
}
9000-
}
90018981
if (IsDummy(symbol) && isImplicitNoneType() &&
90028982
symbol.test(Symbol::Flag::Implicit) && !context().HasError(symbol)) {
90038983
Say(symbol.name(),
@@ -9522,6 +9502,7 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) {
95229502
},
95239503
node.stmt());
95249504
Walk(node.spec());
9505+
bool inDeviceSubprogram = false;
95259506
// If this is a function, convert result to an object. This is to prevent the
95269507
// result from being converted later to a function symbol if it is called
95279508
// inside the function.
@@ -9535,6 +9516,15 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) {
95359516
if (details->isFunction()) {
95369517
ConvertToObjectEntity(const_cast<Symbol &>(details->result()));
95379518
}
9519+
// Check the current procedure is a device procedure to apply implicit
9520+
// attribute at the end.
9521+
if (auto attrs{details->cudaSubprogramAttrs()}) {
9522+
if (*attrs == common::CUDASubprogramAttrs::Device ||
9523+
*attrs == common::CUDASubprogramAttrs::Global ||
9524+
*attrs == common::CUDASubprogramAttrs::Grid_Global) {
9525+
inDeviceSubprogram = true;
9526+
}
9527+
}
95389528
}
95399529
}
95409530
if (node.IsModule()) {
@@ -9561,6 +9551,15 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) {
95619551
symbol.GetType() ? Symbol::Flag::Function : Symbol::Flag::Subroutine);
95629552
}
95639553
ApplyImplicitRules(symbol);
9554+
// Apply CUDA implicit attributes if needed.
9555+
if (inDeviceSubprogram && symbol.has<ObjectEntityDetails>()) {
9556+
auto *object{symbol.detailsIf<ObjectEntityDetails>()};
9557+
if (!object->cudaDataAttr() && !IsValue(symbol) &&
9558+
(IsDummy(symbol) || object->IsArray())) {
9559+
// Implicitly set device attribute if none is set in device context.
9560+
object->set_cudaDataAttr(common::CUDADataAttr::Device);
9561+
}
9562+
}
95649563
}
95659564
}
95669565

flang/test/Semantics/cuf10.cuf

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@ module m
33
real, device :: a(4,8)
44
real, managed, allocatable :: b(:,:)
55
integer, constant :: x = 1
6+
type :: int
7+
real :: i, s
8+
end type int
9+
interface operator (+)
10+
module procedure addHost
11+
module procedure addDevice
12+
end interface operator (+)
613
contains
714
attributes(global) subroutine kernel(a,b,c,n,m)
815
integer, value :: n
@@ -30,4 +37,16 @@ module m
3037
subroutine sub2()
3138
call sub1<<<1,1>>>(x) ! actual constant to device dummy
3239
end
40+
function addHost(a, b) result(c)
41+
type(int), intent(in) :: a, b
42+
type(int) :: c
43+
end function addHost
44+
attributes(device) function addDevice(a, b) result(c)
45+
type(int), device :: c
46+
type(int), intent(in) :: a ,b
47+
end function addDevice
48+
attributes(global) subroutine overload(c, a, b)
49+
type (int) :: c, a, b
50+
c = a+b ! ok resolve to addDevice
51+
end subroutine overload
3352
end

0 commit comments

Comments
 (0)