Skip to content

Commit 0dcb0d6

Browse files
committed
[SYCL] Allow free function kernel args to be templated on integer expressions (#20187)
`constexpr` variables are not forward-declarable, so if one is used as a template argument in the type of a free function kernel argument, the integration header cannot reference the variable; instead, its evaluated value must be inlined into the integration header.
1 parent 176dba1 commit 0dcb0d6

File tree

2 files changed

+145
-9
lines changed

2 files changed

+145
-9
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6597,10 +6597,12 @@ class FreeFunctionPrinter {
65976597
raw_ostream &O;
65986598
PrintingPolicy &Policy;
65996599
bool NSInserted = false;
6600+
ASTContext &Context;
66006601

66016602
public:
6602-
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy)
6603-
: O(O), Policy(PrintPolicy) {}
6603+
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy,
6604+
ASTContext &Context)
6605+
: O(O), Policy(PrintPolicy), Context(Context) {}
66046606

66056607
/// Emits the function declaration of template free function.
66066608
/// \param FTD The function declaration to print.
@@ -6797,18 +6799,42 @@ class FreeFunctionPrinter {
67976799
CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
67986800
ParmListOstream << "<";
67996801

6800-
auto SpecArgs = TST->template_arguments();
6801-
auto DeclArgs = CTST->template_arguments();
6802+
ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
6803+
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();
6804+
6805+
auto TemplateArgPrinter = [&](const TemplateArgument &Arg) {
6806+
if (Arg.getKind() != TemplateArgument::ArgKind::Expression ||
6807+
Arg.isInstantiationDependent()) {
6808+
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
6809+
return;
6810+
}
6811+
6812+
Expr *E = Arg.getAsExpr();
6813+
assert(E && "Failed to get an Expr for an Expression template arg?");
6814+
if (E->getType().getTypePtr()->isScopedEnumeralType()) {
6815+
// Scoped enumerations can't be implicitly cast from integers, so
6816+
// we don't need to evaluate them.
6817+
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
6818+
return;
6819+
}
6820+
6821+
Expr::EvalResult Res;
6822+
[[maybe_unused]] bool Success =
6823+
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
6824+
assert(Success && "invalid non-type template argument?");
6825+
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
6826+
Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(),
6827+
&Context);
6828+
};
68026829

68036830
for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
68046831
SE = SpecArgs.size();
68056832
I < E; ++I) {
68066833
if (I != 0)
68076834
ParmListOstream << ", ";
6808-
if (I < SE) // A specialized argument exists, use it
6809-
SpecArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
6810-
else // Print a canonical form of a default argument
6811-
DeclArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
6835+
// If we have a specialized argument, use it. Otherwise fallback to a
6836+
// default argument.
6837+
TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]);
68126838
}
68136839

68146840
ParmListOstream << ">";
@@ -7207,7 +7233,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
72077233
// template arguments that match default template arguments while printing
72087234
// template-ids, even if the source code doesn't reference them.
72097235
Policy.EnforceDefaultTemplateArgs = true;
7210-
FreeFunctionPrinter FFPrinter(O, Policy);
7236+
FreeFunctionPrinter FFPrinter(O, Policy, S.getASTContext());
72117237
if (FTD) {
72127238
FFPrinter.printFreeFunctionDeclaration(FTD);
72137239
if (const auto kind = K.SyclKernel->getTemplateSpecializationKind();
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
2+
// RUN: FileCheck -input-file=%t.h %s
3+
//
4+
// The purpose of this test is to ensure that forward declarations of free
5+
// function kernels are emitted properly.
6+
// However, this test checks a specific scenario:
7+
// - free function argument is a template which accepts constant expressions as
8+
// arguments
9+
10+
constexpr int A = 2;
11+
constexpr int B = 3;
12+
13+
namespace ns {
14+
15+
constexpr int C = 4;
16+
17+
struct Foo {
18+
static constexpr int D = 5;
19+
};
20+
21+
enum non_class_enum {
22+
VAL_A,
23+
VAL_B
24+
};
25+
26+
enum class class_enum {
27+
VAL_A,
28+
VAL_B
29+
};
30+
31+
enum non_class_enum_typed : int {
32+
VAL_C,
33+
VAL_D
34+
};
35+
36+
enum class class_enum_typed : int {
37+
VAL_C,
38+
VAL_D
39+
};
40+
41+
constexpr int bar(int arg) {
42+
return arg + 42;
43+
}
44+
45+
} // namespace ns
46+
47+
template<int V>
48+
struct Arg {};
49+
50+
template<ns::non_class_enum V>
51+
struct Arg2 {};
52+
53+
template<ns::non_class_enum_typed V>
54+
struct Arg3 {};
55+
56+
template<ns::class_enum V>
57+
struct Arg4 {};
58+
59+
template<ns::class_enum_typed V>
60+
struct Arg5 {};
61+
62+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
63+
void constant(Arg<1>) {}
64+
65+
// CHECK: void constant(Arg<1> );
66+
67+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
68+
void constexpr_v(Arg<A>) {}
69+
70+
// CHECK: void constexpr_v(Arg<2> );
71+
72+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
73+
void constexpr_expr(Arg<A * B>) {}
74+
75+
// CHECK: void constexpr_expr(Arg<6> );
76+
77+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
78+
void constexpr_ns(Arg<ns::C>) {}
79+
80+
// CHECK: void constexpr_ns(Arg<4> );
81+
82+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
83+
void constexpr_ns2(Arg<ns::Foo::D>) {}
84+
85+
// CHECK: void constexpr_ns2(Arg<5> );
86+
87+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
88+
void constexpr_ns2(Arg2<ns::non_class_enum::VAL_A>) {}
89+
90+
// CHECK: void constexpr_ns2(Arg2<ns::VAL_A> );
91+
92+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
93+
void constexpr_ns2(Arg3<ns::non_class_enum_typed::VAL_C>) {}
94+
95+
// CHECK: void constexpr_ns2(Arg3<ns::VAL_C> );
96+
97+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
98+
void constexpr_ns2(Arg4<ns::class_enum::VAL_A>) {}
99+
100+
// CHECK: void constexpr_ns2(Arg4<ns::class_enum::VAL_A> );
101+
102+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
103+
void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C>) {}
104+
105+
// CHECK: void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C> );
106+
107+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
108+
void constexpr_call(Arg<ns::bar(B)>) {}
109+
110+
// CHECK: void constexpr_call(Arg<45> );

0 commit comments

Comments
 (0)