Skip to content

Commit ba151e8

Browse files
committed
[CUDA][HIP] Fix host/device context in concept
Currently, constraints are checked in Sema::FinishTemplateArgumentDeduction, where the current function in ASTContext is set to the instantiated template function. When resolving functions for the constraints, clang assumes the caller is the current function, This causes incompatibility with nvcc and also for constexpr template functions with C++. clang caches the constraint checking result per concept/type matching. It assumes the result does not depend on the instantiation context. This patch let constraint checking have its own host/device context and by default it is host to be compatible with C++. This makes the constraint checking independent of callers and make the caching valid. In the future, we may introduce device constraints by other means, e.g. adding __device__ attribute per function call in constraints. Fixes: #67507
1 parent 3231a36 commit ba151e8

File tree

5 files changed

+83
-15
lines changed

5 files changed

+83
-15
lines changed

clang/docs/HIPSupport.rst

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,34 @@ Predefined Macros
176176
* - ``HIP_API_PER_THREAD_DEFAULT_STREAM``
177177
- Alias to ``__HIP_API_PER_THREAD_DEFAULT_STREAM__``. Deprecated.
178178

179+
C++20 Concepts with HIP and CUDA
180+
--------------------------------
181+
182+
In Clang, when working with HIP or CUDA, it's important to note that all constraints in C++20 concepts are assumed to be for the host side only. This behavior is consistent across both programming models, and developers should be aware of this assumption when writing code that utilizes C++20 concepts.
183+
184+
Example:
185+
.. code-block:: c++
186+
187+
template <class T>
188+
concept MyConcept = requires(T& obj) {
189+
my_function(obj); // Assumed to be a host-side requirement
190+
};
191+
192+
template <MyConcept T>
193+
__global__ void kernel() {
194+
// Kernel code
195+
}
196+
197+
struct MyType {};
198+
199+
inline void my_function(MyType& obj) {}
200+
201+
int main() {
202+
kernel<MyType><<<1,1>>>();
203+
return 0;
204+
}
205+
206+
In the above example, the ``MyConcept`` concept is assumed to check the host-side requirements, even though it's being used in a device kernel. Developers should structure their code accordingly to ensure correct behavior and to satisfy the host-side constraints assumed by Clang.
207+
208+
This assumption helps maintain a consistent behavior when dealing with template constraints, and simplifies the compilation model by reducing the complexity associated with differentiating between host and device-side requirements.
209+

clang/include/clang/Sema/Sema.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13312,20 +13312,25 @@ class Sema final {
1331213312
CTCK_Unknown, /// Unknown context
1331313313
CTCK_InitGlobalVar, /// Function called during global variable
1331413314
/// initialization
13315+
CTCK_Constraint, /// Function called for constraint checking
1331513316
};
1331613317

1331713318
/// Define the current global CUDA host/device context where a function may be
1331813319
/// called. Only used when a function is called outside of any functions.
1331913320
struct CUDATargetContext {
1332013321
CUDAFunctionTarget Target = CFT_HostDevice;
1332113322
CUDATargetContextKind Kind = CTCK_Unknown;
13322-
Decl *D = nullptr;
13323+
const Decl *D = nullptr;
13324+
const Expr *E = nullptr;
13325+
/// Whether should override the current function.
13326+
bool shouldOverride(const Decl *D) const;
1332313327
} CurCUDATargetCtx;
1332413328

1332513329
struct CUDATargetContextRAII {
1332613330
Sema &S;
1332713331
CUDATargetContext SavedCtx;
13328-
CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, Decl *D);
13332+
CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, const Decl *D,
13333+
const Expr *E = nullptr);
1332913334
~CUDATargetContextRAII() { S.CurCUDATargetCtx = SavedCtx; }
1333013335
};
1333113336

clang/lib/Sema/SemaCUDA.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,27 +114,34 @@ static bool hasAttr(const Decl *D, bool IgnoreImplicitAttr) {
114114

115115
Sema::CUDATargetContextRAII::CUDATargetContextRAII(Sema &S_,
116116
CUDATargetContextKind K,
117-
Decl *D)
117+
const Decl *D, const Expr *E)
118118
: S(S_) {
119119
SavedCtx = S.CurCUDATargetCtx;
120-
assert(K == CTCK_InitGlobalVar);
121-
auto *VD = dyn_cast_or_null<VarDecl>(D);
122-
if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
123-
auto Target = CFT_Host;
124-
if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
125-
!hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
126-
hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
127-
hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
128-
Target = CFT_Device;
129-
S.CurCUDATargetCtx = {Target, K, VD};
120+
auto Target = CFT_Host;
121+
if (K == CTCK_InitGlobalVar) {
122+
auto *VD = dyn_cast_or_null<VarDecl>(D);
123+
if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
124+
if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
125+
!hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
126+
hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
127+
hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
128+
Target = CFT_Device;
129+
S.CurCUDATargetCtx = {Target, K, D, E};
130+
}
131+
return;
130132
}
133+
assert(K == CTCK_Constraint);
134+
S.CurCUDATargetCtx = {Target, K, D, E};
135+
}
136+
137+
bool Sema::CUDATargetContext::shouldOverride(const Decl *D) const {
138+
return Kind == CTCK_Constraint || D == nullptr;
131139
}
132140

133141
/// IdentifyCUDATarget - Determine the CUDA compilation target for this function
134142
Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
135143
bool IgnoreImplicitHDAttr) {
136-
// Code that lives outside a function gets the target from CurCUDATargetCtx.
137-
if (D == nullptr)
144+
if (CurCUDATargetCtx.shouldOverride(D))
138145
return CurCUDATargetCtx.Target;
139146

140147
if (D->hasAttr<CUDAInvalidTargetAttr>())

clang/lib/Sema/SemaConcept.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ static ExprResult calculateConstraintSatisfaction(
336336
Sema &S, const NamedDecl *Template, SourceLocation TemplateNameLoc,
337337
const MultiLevelTemplateArgumentList &MLTAL, const Expr *ConstraintExpr,
338338
ConstraintSatisfaction &Satisfaction) {
339+
Sema::CUDATargetContextRAII X(S, Sema::CTCK_Constraint,
340+
/*Decl=*/nullptr, ConstraintExpr);
339341
return calculateConstraintSatisfaction(
340342
S, ConstraintExpr, Satisfaction, [&](const Expr *AtomicExpr) {
341343
EnterExpressionEvaluationContext ConstantEvaluated(

clang/test/SemaCUDA/concept.cu

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: %clang_cc1 -triple amdgcn-amd-amdhsa -fcuda-is-device -x hip %s \
2+
// RUN: -std=c++20 -fsyntax-only -verify
3+
// RUN: %clang_cc1 -triple x86_64 -x hip %s \
4+
// RUN: -std=c++20 -fsyntax-only -verify
5+
6+
// expected-no-diagnostics
7+
8+
#include "Inputs/cuda.h"
9+
10+
template <class T>
11+
concept C = requires(T x) {
12+
func(x);
13+
};
14+
15+
struct A {};
16+
void func(A x) {}
17+
18+
template <C T> __global__ void kernel(T x) { }
19+
20+
int main() {
21+
A a;
22+
kernel<<<1,1>>>(a);
23+
}

0 commit comments

Comments
 (0)