Skip to content

Commit b28384d

Browse files
committed
[CUDA][HIP] Fix deduction guide
Currently clang assumes implicit deduction guide to be host device. This generates two identical implicit deduction guides when a class have a device and a host constructor which have the same input parameter, which causes ambiguity. Since an implicit deduction guide is derived from a constructor, it should take the same host/device attribute as the originating constructor. This matches nvcc behavior as seen in https://godbolt.org/z/sY1vdYWKe and https://godbolt.org/z/vTer7xa3j
1 parent 0b2924a commit b28384d

File tree

6 files changed

+179
-14
lines changed

6 files changed

+179
-14
lines changed

clang/docs/HIPSupport.rst

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,58 @@ Predefined Macros
176176
* - ``HIP_API_PER_THREAD_DEFAULT_STREAM``
177177
- Alias to ``__HIP_API_PER_THREAD_DEFAULT_STREAM__``. Deprecated.
178178

179+
Support for Deduction Guides
180+
============================
181+
182+
Explicit Deduction Guides
183+
-------------------------
184+
185+
Explicit deduction guides in HIP can be annotated with either the
186+
``__host__`` or ``__device__`` attributes. If no attribute is provided,
187+
it defaults to ``__host__``.
188+
189+
.. code-block:: cpp
190+
191+
template <typename T>
192+
class MyArray {
193+
//...
194+
};
195+
196+
template <typename T>
197+
MyArray(T)->MyArray<T>;
198+
199+
__device__ MyArray(float)->MyArray<int>;
200+
201+
// Uses of the deduction guides
202+
MyArray arr1 = 10; // Uses the default host guide
203+
__device__ void foo() {
204+
MyArray arr2 = 3.14f; // Uses the device guide
205+
}
206+
207+
Implicit Deduction Guides
208+
-------------------------
209+
Implicit deduction guides derived from constructors inherit the same host or
210+
device attributes as the originating constructor.
211+
212+
.. code-block:: cpp
213+
214+
template <typename T>
215+
class MyVector {
216+
public:
217+
__device__ MyVector(T) { /* ... */ }
218+
//...
219+
};
220+
221+
// The implicit deduction guide for MyVector will be `__device__` due to the device constructor
222+
223+
__device__ void foo() {
224+
MyVector vec(42); // Uses the implicit device guide derived from the constructor
225+
}
226+
227+
Availability Checks
228+
--------------------
229+
When a deduction guide (either explicit or implicit) is used, HIP checks its
230+
availability based on its host/device attributes and the context in a similar
231+
way as checking a function. Utilizing a deduction guide in an incompatible context
232+
results in a compile-time error.
233+

clang/include/clang/AST/DeclCXX.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,14 +1948,7 @@ class CXXDeductionGuideDecl : public FunctionDecl {
19481948
ExplicitSpecifier ES,
19491949
const DeclarationNameInfo &NameInfo, QualType T,
19501950
TypeSourceInfo *TInfo, SourceLocation EndLocation,
1951-
CXXConstructorDecl *Ctor, DeductionCandidate Kind)
1952-
: FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
1953-
SC_None, false, false, ConstexprSpecKind::Unspecified),
1954-
Ctor(Ctor), ExplicitSpec(ES) {
1955-
if (EndLocation.isValid())
1956-
setRangeEnd(EndLocation);
1957-
setDeductionCandidateKind(Kind);
1958-
}
1951+
CXXConstructorDecl *Ctor, DeductionCandidate Kind);
19591952

19601953
CXXConstructorDecl *Ctor;
19611954
ExplicitSpecifier ExplicitSpec;

clang/lib/AST/DeclCXX.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2113,6 +2113,27 @@ ExplicitSpecifier ExplicitSpecifier::getFromDecl(FunctionDecl *Function) {
21132113
}
21142114
}
21152115

2116+
CXXDeductionGuideDecl::CXXDeductionGuideDecl(
2117+
ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
2118+
ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,
2119+
TypeSourceInfo *TInfo, SourceLocation EndLocation, CXXConstructorDecl *Ctor,
2120+
DeductionCandidate Kind)
2121+
: FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
2122+
SC_None, false, false, ConstexprSpecKind::Unspecified),
2123+
Ctor(Ctor), ExplicitSpec(ES) {
2124+
if (EndLocation.isValid())
2125+
setRangeEnd(EndLocation);
2126+
setDeductionCandidateKind(Kind);
2127+
// If Ctor is not nullptr, this deduction guide is implicitly derived from
2128+
// the ctor, therefore it should have the same host/device attribute.
2129+
if (Ctor && C.getLangOpts().CUDA) {
2130+
if (Ctor->hasAttr<CUDAHostAttr>())
2131+
this->addAttr(CUDAHostAttr::CreateImplicit(C));
2132+
if (Ctor->hasAttr<CUDADeviceAttr>())
2133+
this->addAttr(CUDADeviceAttr::CreateImplicit(C));
2134+
}
2135+
}
2136+
21162137
CXXDeductionGuideDecl *CXXDeductionGuideDecl::Create(
21172138
ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
21182139
ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,

clang/lib/Sema/SemaCUDA.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,13 @@ Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
149149
return CFT_Device;
150150
} else if (hasAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr)) {
151151
return CFT_Host;
152-
} else if ((D->isImplicit() || !D->isUserProvided()) &&
152+
} else if (!isa<CXXDeductionGuideDecl>(D) &&
153+
(D->isImplicit() || !D->isUserProvided()) &&
153154
!IgnoreImplicitHDAttr) {
154155
// Some implicit declarations (like intrinsic functions) are not marked.
155156
// Set the most lenient target on them for maximal flexibility.
157+
// Implicit deduction duides are derived from constructors and their
158+
// host/device attributes are determined by their originating constructors.
156159
return CFT_HostDevice;
157160
}
158161

clang/lib/Sema/SemaTemplate.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,19 +2685,27 @@ void Sema::DeclareImplicitDeductionGuides(TemplateDecl *Template,
26852685
AddedAny = true;
26862686
}
26872687

2688+
// Build simple deduction guide and set CUDA host/device attributes.
2689+
auto BuildSimpleDeductionGuide = [&](auto T) {
2690+
auto *DG = cast<CXXDeductionGuideDecl>(
2691+
cast<FunctionTemplateDecl>(Transform.buildSimpleDeductionGuide(T))
2692+
->getTemplatedDecl());
2693+
if (LangOpts.CUDA) {
2694+
DG->addAttr(CUDAHostAttr::CreateImplicit(getASTContext()));
2695+
DG->addAttr(CUDADeviceAttr::CreateImplicit(getASTContext()));
2696+
}
2697+
return DG;
2698+
};
26882699
// C++17 [over.match.class.deduct]
26892700
// -- If C is not defined or does not declare any constructors, an
26902701
// additional function template derived as above from a hypothetical
26912702
// constructor C().
26922703
if (!AddedAny)
2693-
Transform.buildSimpleDeductionGuide(std::nullopt);
2704+
BuildSimpleDeductionGuide(std::nullopt);
26942705

26952706
// -- An additional function template derived as above from a hypothetical
26962707
// constructor C(C), called the copy deduction candidate.
2697-
cast<CXXDeductionGuideDecl>(
2698-
cast<FunctionTemplateDecl>(
2699-
Transform.buildSimpleDeductionGuide(Transform.DeducedType))
2700-
->getTemplatedDecl())
2708+
BuildSimpleDeductionGuide(Transform.DeducedType)
27012709
->setDeductionCandidateKind(DeductionCandidate::Copy);
27022710
}
27032711

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// RUN: %clang_cc1 -fsyntax-only -verify=expected,host %s
2+
// RUN: %clang_cc1 -fcuda-is-device -fsyntax-only -verify=expected,dev %s
3+
4+
#include "Inputs/cuda.h"
5+
6+
// Implicit deduction guide for host.
7+
template <typename T>
8+
struct HGuideImp { // expected-note {{candidate template ignored: could not match 'HGuideImp<T>' against 'int'}}
9+
HGuideImp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
10+
// dev-note@-1 {{'<deduction guide for HGuideImp><int>' declared here}}
11+
// dev-note@-2 {{'HGuideImp' declared here}}
12+
};
13+
14+
// Explicit deduction guide for host.
15+
template <typename T>
16+
struct HGuideExp { // expected-note {{candidate template ignored: could not match 'HGuideExp<T>' against 'int'}}
17+
HGuideExp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
18+
// dev-note@-1 {{'HGuideExp' declared here}}
19+
};
20+
template<typename T>
21+
HGuideExp(T) -> HGuideExp<T>; // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
22+
// dev-note@-1 {{'<deduction guide for HGuideExp><int>' declared here}}
23+
24+
// Implicit deduction guide for device.
25+
template <typename T>
26+
struct DGuideImp { // expected-note {{candidate template ignored: could not match 'DGuideImp<T>' against 'int'}}
27+
__device__ DGuideImp(T value) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
28+
// host-note@-1 {{'<deduction guide for DGuideImp><int>' declared here}}
29+
// host-note@-2 {{'DGuideImp' declared here}}
30+
};
31+
32+
// Explicit deduction guide for device.
33+
template <typename T>
34+
struct DGuideExp { // expected-note {{candidate template ignored: could not match 'DGuideExp<T>' against 'int'}}
35+
__device__ DGuideExp(T value) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
36+
// host-note@-1 {{'DGuideExp' declared here}}
37+
};
38+
39+
template<typename T>
40+
__device__ DGuideExp(T) -> DGuideExp<T>; // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
41+
// host-note@-1 {{'<deduction guide for DGuideExp><int>' declared here}}
42+
43+
template <typename T>
44+
struct HDGuide {
45+
__device__ HDGuide(T value) {}
46+
HDGuide(T value) {}
47+
};
48+
49+
template<typename T>
50+
HDGuide(T) -> HDGuide<T>;
51+
52+
template<typename T>
53+
__device__ HDGuide(T) -> HDGuide<T>;
54+
55+
void hfun() {
56+
HGuideImp hgi = 10;
57+
HGuideExp hge = 10;
58+
DGuideImp dgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideImp'}}
59+
DGuideExp dge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideExp'}}
60+
HDGuide hdg = 10;
61+
}
62+
63+
__device__ void dfun() {
64+
HGuideImp hgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideImp'}}
65+
HGuideExp hge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideExp'}}
66+
DGuideImp dgi = 10;
67+
DGuideExp dge = 10;
68+
HDGuide hdg = 10;
69+
}
70+
71+
__host__ __device__ void hdfun() {
72+
HGuideImp hgi = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideImp><int>' in __host__ __device__ function}}
73+
// dev-error@-1 {{reference to __host__ function 'HGuideImp' in __host__ __device__ function}}
74+
HGuideExp hge = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideExp><int>' in __host__ __device__ function}}
75+
// dev-error@-1 {{reference to __host__ function 'HGuideExp' in __host__ __device__ function}}
76+
DGuideImp dgi = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideImp><int>' in __host__ __device__ function}}
77+
// host-error@-1 {{reference to __device__ function 'DGuideImp' in __host__ __device__ function}}
78+
DGuideExp dge = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideExp><int>' in __host__ __device__ function}}
79+
// host-error@-1 {{reference to __device__ function 'DGuideExp' in __host__ __device__ function}}
80+
HDGuide hdg = 10;
81+
}
82+
83+
HGuideImp hgi = 10;
84+
HGuideExp hge = 10;
85+
HDGuide hdg = 10;

0 commit comments

Comments
 (0)