Skip to content

Commit c3cdacb

Browse files
committed
[SYCL][clang] Emit default template arguments in integration header
For free function kernels support clang forward declares the kernel itself as well as its parameter types. In case a free function kernel has a parameter that is templated and has a default template argument, all template arguments including arguments that match default arguments must be printed in kernel's forward declarations, for example ``` template <typename T, typename = int> struct Arg { T val; }; // For the kernel SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( (ext::oneapi::experimental::nd_range_kernel<1>)) void foo(Arg<int> arg) { arg.val = 42; } // Integration header must contain void foo(Arg<int, int> arg); ``` Unfortunately, even though integration header emission already has extensive support for forward declarations priting, some modifications to clang's type printing are still required. infrastructure, since neither of existing PrintingPolicy flags help to reach the correct result. Using `SuppressDefaultTemplateArgs = true` doesn't help without printing canonical types, printing canonical types for the case like ``` template <typename T> SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( (ext::oneapi::experimental::nd_range_kernel<1>)) void foo(Arg<T> arg) { arg.val = 42; } // Printing canonical types is causing the following integration header template <typename T> void foo(Arg<type-parameter-0-0, int> arg); ``` Using `SkipCanonicalizationOfTemplateTypeParms` field of printing policy doesn't help here since at the one point where it is checked we take canonical type of `Arg`, not its parameters and it will contain template argument types in canonical type after that.
1 parent 4a274fc commit c3cdacb

File tree

3 files changed

+44
-13
lines changed

3 files changed

+44
-13
lines changed

clang/include/clang/AST/PrettyPrinter.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ struct PrintingPolicy {
6666
SuppressLifetimeQualifiers(false), SuppressTypedefs(false),
6767
SuppressFinalSpecifier(false),
6868
SuppressTemplateArgsInCXXConstructors(false),
69-
SuppressDefaultTemplateArgs(true), Bool(LO.Bool),
70-
Nullptr(LO.CPlusPlus11 || LO.C23), NullptrTypeInNamespace(LO.CPlusPlus),
71-
Restrict(LO.C99), Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11),
69+
SuppressDefaultTemplateArgs(true), EnforceDefaultTemplateArgs(false),
70+
Bool(LO.Bool), Nullptr(LO.CPlusPlus11 || LO.C23),
71+
NullptrTypeInNamespace(LO.CPlusPlus), Restrict(LO.C99),
72+
Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11),
7273
UseVoidForZeroParams(!LO.CPlusPlus),
7374
SplitTemplateClosers(!LO.CPlusPlus11), TerseOutput(false),
7475
PolishForDeclaration(false), Half(LO.Half),
@@ -237,6 +238,12 @@ struct PrintingPolicy {
237238
LLVM_PREFERRED_TYPE(bool)
238239
unsigned SuppressDefaultTemplateArgs : 1;
239240

241+
242+
/// When true, print template arguments that match the default argument for
243+
/// the parameter, even if they're not specified in the source.
244+
LLVM_PREFERRED_TYPE(bool)
245+
unsigned EnforceDefaultTemplateArgs : 1;
246+
240247
/// Whether we can use 'bool' rather than '_Bool' (even if the language
241248
/// doesn't actually have 'bool', because, e.g., it is defined as a macro).
242249
LLVM_PREFERRED_TYPE(bool)

clang/lib/AST/TypePrinter.cpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2397,15 +2397,35 @@ template <typename TA>
23972397
static void
23982398
printTo(raw_ostream &OS, ArrayRef<TA> Args, const PrintingPolicy &Policy,
23992399
const TemplateParameterList *TPL, bool IsPack, unsigned ParmIndex) {
2400-
// Drop trailing template arguments that match default arguments.
2401-
if (TPL && Policy.SuppressDefaultTemplateArgs &&
2402-
!Policy.PrintCanonicalTypes && !Args.empty() && !IsPack &&
2400+
llvm::SmallVector<TemplateArgument, 8> OrigArgs;
2401+
for (const TA &A : Args)
2402+
OrigArgs.push_back(getArgument(A));
2403+
if (TPL && !Policy.PrintCanonicalTypes && !IsPack &&
24032404
Args.size() <= TPL->size()) {
2404-
llvm::SmallVector<TemplateArgument, 8> OrigArgs;
2405-
for (const TA &A : Args)
2406-
OrigArgs.push_back(getArgument(A));
2407-
while (!Args.empty() && getArgument(Args.back()).getIsDefaulted())
2408-
Args = Args.drop_back();
2405+
// Drop trailing template arguments that match default arguments.
2406+
if (Policy.SuppressDefaultTemplateArgs) {
2407+
while (!OrigArgs.empty() && getArgument(OrigArgs.back()).getIsDefaulted())
2408+
OrigArgs.pop_back();
2409+
} else if (Policy.EnforceDefaultTemplateArgs) {
2410+
for (unsigned I = Args.size(); I < TPL->size(); ++I) {
2411+
auto Param = TPL->getParam(I);
2412+
if (auto *TTPD = dyn_cast<TemplateTypeParmDecl>(Param)) {
2413+
if (!TTPD->hasDefaultArgument())
2414+
break;
2415+
OrigArgs.push_back(getArgument(TTPD->getDefaultArgument()));
2416+
} else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(Param)) {
2417+
if (!TTPD->hasDefaultArgument())
2418+
break;
2419+
OrigArgs.push_back(getArgument(TTPD->getDefaultArgument()));
2420+
} else if (auto *NTTPD = dyn_cast<NonTypeTemplateParmDecl>(Param)) {
2421+
if (!NTTPD->hasDefaultArgument())
2422+
break;
2423+
OrigArgs.push_back(getArgument(NTTPD->getDefaultArgument()));
2424+
} else {
2425+
llvm_unreachable("unexpected template parameter");
2426+
}
2427+
}
2428+
}
24092429
}
24102430

24112431
const char *Comma = Policy.MSVCFormatting ? "," : ", ";
@@ -2414,7 +2434,7 @@ printTo(raw_ostream &OS, ArrayRef<TA> Args, const PrintingPolicy &Policy,
24142434

24152435
bool NeedSpace = false;
24162436
bool FirstArg = true;
2417-
for (const auto &Arg : Args) {
2437+
for (const auto &Arg : OrigArgs) {
24182438
// Print the argument into a string.
24192439
SmallString<128> Buf;
24202440
llvm::raw_svector_ostream ArgOS(Buf);

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6464,6 +6464,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
64646464

64656465
O << "\n";
64666466
O << "// Forward declarations of kernel and its argument types:\n";
6467+
Policy.SuppressDefaultTemplateArgs = false;
64676468
FwdDeclEmitter.Visit(K.SyclKernel->getType());
64686469
O << "\n";
64696470

@@ -6476,11 +6477,12 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
64766477
FirstParam = false;
64776478
else
64786479
ParmList += ", ";
6479-
ParmList += Param->getType().getCanonicalType().getAsString();
6480+
ParmList += Param->getType().getCanonicalType().getAsString(Policy);
64806481
}
64816482
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
64826483
Policy.SuppressDefinition = true;
64836484
Policy.PolishForDeclaration = true;
6485+
Policy.EnforceDefaultTemplateArgs = true;
64846486
if (FTD) {
64856487
FTD->print(O, Policy);
64866488
} else {
@@ -6509,6 +6511,8 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
65096511
}
65106512
O << ";\n";
65116513
O << "}\n";
6514+
Policy.SuppressDefaultTemplateArgs = true;
6515+
Policy.EnforceDefaultTemplateArgs = false;
65126516

65136517
// Generate is_kernel, is_single_task_kernel and nd_range_kernel functions.
65146518
O << "namespace sycl {\n";

0 commit comments

Comments
 (0)