Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions clang/include/clang/AST/PrettyPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,18 @@ struct PrintingPolicy {
SuppressStrongLifetime(false), SuppressLifetimeQualifiers(false),
SuppressTypedefs(false), SuppressFinalSpecifier(false),
SuppressTemplateArgsInCXXConstructors(false),
SuppressDefaultTemplateArgs(true), Bool(LO.Bool),
Nullptr(LO.CPlusPlus11 || LO.C23), NullptrTypeInNamespace(LO.CPlusPlus),
Restrict(LO.C99), Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11),
SuppressDefaultTemplateArgs(true), EnforceDefaultTemplateArgs(false),
Bool(LO.Bool), Nullptr(LO.CPlusPlus11 || LO.C23),
NullptrTypeInNamespace(LO.CPlusPlus), Restrict(LO.C99),
Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11),
UseVoidForZeroParams(!LO.CPlusPlus),
SplitTemplateClosers(!LO.CPlusPlus11), TerseOutput(false),
PolishForDeclaration(false), Half(LO.Half),
MSWChar(LO.MicrosoftExt && !LO.WChar), IncludeNewlines(true),
MSVCFormatting(false), ConstantsAsWritten(false),
SuppressImplicitBase(false), FullyQualifiedName(false),
SuppressDefinition(false), SuppressDefaultTemplateArguments(false),
PrintCanonicalTypes(false),
EnforceScopeForElaboratedTypes(false), SuppressDefinition(false),
SuppressDefaultTemplateArguments(false), PrintCanonicalTypes(false),
SkipCanonicalizationOfTemplateTypeParms(false),
PrintInjectedClassNameWithArguments(true), UsePreferredNames(true),
AlwaysIncludeTypeForTemplateArgument(false),
Expand Down Expand Up @@ -241,6 +242,11 @@ struct PrintingPolicy {
LLVM_PREFERRED_TYPE(bool)
unsigned SuppressDefaultTemplateArgs : 1;

/// When true, print template arguments that match the default argument for
/// the parameter, even if they're not specified in the source.
LLVM_PREFERRED_TYPE(bool)
unsigned EnforceDefaultTemplateArgs : 1;

/// Whether we can use 'bool' rather than '_Bool' (even if the language
/// doesn't actually have 'bool', because, e.g., it is defined as a macro).
LLVM_PREFERRED_TYPE(bool)
Expand Down Expand Up @@ -339,6 +345,10 @@ struct PrintingPolicy {
LLVM_PREFERRED_TYPE(bool)
unsigned FullyQualifiedName : 1;

/// Enforce fully qualified name printing for elaborated types.
LLVM_PREFERRED_TYPE(bool)
unsigned EnforceScopeForElaboratedTypes : 1;

/// When true does not print definition of a type. E.g.
/// \code
/// template<typename T> class C0 : public C1 {...}
Expand Down
60 changes: 39 additions & 21 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class ElaboratedTypePolicyRAII {
SuppressTagKeyword = Policy.SuppressTagKeyword;
SuppressScope = Policy.SuppressScope;
Policy.SuppressTagKeyword = true;
Policy.SuppressScope = true;
Policy.SuppressScope = !Policy.EnforceScopeForElaboratedTypes;
}

~ElaboratedTypePolicyRAII() {
Expand Down Expand Up @@ -1728,8 +1728,10 @@ void TypePrinter::printElaboratedBefore(const ElaboratedType *T,
Policy.SuppressScope = OldSupressScope;
return;
}
if (Qualifier && !(Policy.SuppressTypedefs &&
T->getNamedType()->getTypeClass() == Type::Typedef))
if (Qualifier &&
!(Policy.SuppressTypedefs &&
T->getNamedType()->getTypeClass() == Type::Typedef) &&
!Policy.EnforceScopeForElaboratedTypes)
Qualifier->print(OS, Policy);
}

Expand Down Expand Up @@ -2220,15 +2222,6 @@ static void printArgument(const TemplateArgument &A, const PrintingPolicy &PP,
A.print(PP, OS, IncludeType);
}

static void printArgument(const TemplateArgumentLoc &A,
const PrintingPolicy &PP, llvm::raw_ostream &OS,
bool IncludeType) {
const TemplateArgument::ArgKind &Kind = A.getArgument().getKind();
if (Kind == TemplateArgument::ArgKind::Type)
return A.getTypeSourceInfo()->getType().print(OS, PP);
return A.getArgument().print(PP, OS, IncludeType);
}

static bool isSubstitutedTemplateArgument(ASTContext &Ctx, TemplateArgument Arg,
TemplateArgument Pattern,
ArrayRef<TemplateArgument> Args,
Expand Down Expand Up @@ -2399,15 +2392,40 @@ template <typename TA>
static void
printTo(raw_ostream &OS, ArrayRef<TA> Args, const PrintingPolicy &Policy,
const TemplateParameterList *TPL, bool IsPack, unsigned ParmIndex) {
// Drop trailing template arguments that match default arguments.
if (TPL && Policy.SuppressDefaultTemplateArgs &&
!Policy.PrintCanonicalTypes && !Args.empty() && !IsPack &&
llvm::SmallVector<TemplateArgument, 8> ArgsToPrint;
for (const TA &A : Args)
ArgsToPrint.push_back(getArgument(A));
if (TPL && !Policy.PrintCanonicalTypes && !IsPack &&
Args.size() <= TPL->size()) {
llvm::SmallVector<TemplateArgument, 8> OrigArgs;
for (const TA &A : Args)
OrigArgs.push_back(getArgument(A));
while (!Args.empty() && getArgument(Args.back()).getIsDefaulted())
Args = Args.drop_back();
// Drop trailing template arguments that match default arguments.
if (Policy.SuppressDefaultTemplateArgs) {
while (!ArgsToPrint.empty() &&
getArgument(ArgsToPrint.back()).getIsDefaulted())
ArgsToPrint.pop_back();
} else if (Policy.EnforceDefaultTemplateArgs) {
for (unsigned I = Args.size(); I < TPL->size(); ++I) {
auto Param = TPL->getParam(I);
if (auto *TTPD = dyn_cast<TemplateTypeParmDecl>(Param)) {
// If we met a non default-argument past provided list of arguments,
// it is either a pack which must be the last arguments, or provided
// argument list was problematic. Bail out either way. Do the same
// for each kind of template argument.
if (!TTPD->hasDefaultArgument())
break;
ArgsToPrint.push_back(getArgument(TTPD->getDefaultArgument()));
} else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(Param)) {
if (!TTPD->hasDefaultArgument())
break;
ArgsToPrint.push_back(getArgument(TTPD->getDefaultArgument()));
} else if (auto *NTTPD = dyn_cast<NonTypeTemplateParmDecl>(Param)) {
if (!NTTPD->hasDefaultArgument())
break;
ArgsToPrint.push_back(getArgument(NTTPD->getDefaultArgument()));
} else {
llvm_unreachable("unexpected template parameter");
}
}
}
}

const char *Comma = Policy.MSVCFormatting ? "," : ", ";
Expand All @@ -2416,7 +2434,7 @@ printTo(raw_ostream &OS, ArrayRef<TA> Args, const PrintingPolicy &Policy,

bool NeedSpace = false;
bool FirstArg = true;
for (const auto &Arg : Args) {
for (const auto &Arg : ArgsToPrint) {
// Print the argument into a string.
SmallString<128> Buf;
llvm::raw_svector_ostream ArgOS(Buf);
Expand Down
32 changes: 31 additions & 1 deletion clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6483,16 +6483,46 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
O << "extern \"C\" ";
std::string ParmList;
bool FirstParam = true;
Policy.SuppressDefaultTemplateArgs = false;
for (ParmVarDecl *Param : K.SyclKernel->parameters()) {
if (FirstParam)
FirstParam = false;
else
ParmList += ", ";
ParmList += Param->getType().getCanonicalType().getAsString();
ParmList += Param->getType().getCanonicalType().getAsString(Policy);
}
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
Policy.SuppressDefinition = true;
Policy.PolishForDeclaration = true;
Policy.FullyQualifiedName = true;
Policy.EnforceScopeForElaboratedTypes = true;

// Now we need to print the declaration of the kernel itself.
// Example:
// template <typename T, typename = int> struct Arg {
// T val;
// };
// For the following free function kernel:
// template <typename = T>
// SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
// (ext::oneapi::experimental::nd_range_kernel<1>))
// void foo(Arg<int> arg) {}
// Integration header must contain the following declaration:
// template <typename>
// void foo(Arg<int, int> arg);
// SuppressDefaultTemplateArguments is a downstream addition that suppresses
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think an example in the comments would be helpful here. Probably the same one in PR description. Just reading these is very confusing without context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, thanks!

// default template arguments in the function declaration. It should be set
// to true to emit function declaration that won't cause any compilation
// errors when present in the integration header.
// To print Arg<int, int> in the function declaration and shim functions we
// need to disable default arguments printing suppression via community flag
// SuppressDefaultTemplateArgs, otherwise they will be suppressed even for
// canonical types or if even written in the original source code.
Policy.SuppressDefaultTemplateArguments = true;
// EnforceDefaultTemplateArgs is a downstream addition that forces printing
// template arguments that match default template arguments while printing
// template-ids, even if the source code doesn't reference them.
Policy.EnforceDefaultTemplateArgs = true;
if (FTD) {
FTD->print(O, Policy);
} else {
Expand Down
100 changes: 100 additions & 0 deletions clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
// RUN: FileCheck -input-file=%t.h %s

// This test checks integration header contents for free functions kernels with
// parameter types that have default template arguments.

#include "mock_properties.hpp"
#include "sycl.hpp"

namespace ns {

struct notatuple {
int a;
};

namespace ns1 {
template <typename A = notatuple>
class hasDefaultArg {

};
}

template <typename T, typename = int, int a = 12, typename = notatuple, typename ...TS> struct Arg {
T val;
};

[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel",
2)]] void
simple(Arg<char>){
}

}

[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel",
2)]] void
simple1(ns::Arg<ns::ns1::hasDefaultArg<>>){
}


template <typename T>
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void
templated(ns::Arg<T, float, 3>, T end) {
}

template void templated(ns::Arg<int, float, 3>, int);

using namespace ns;

template <typename T>
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void
templated2(Arg<T, notatuple>, T end) {
}

template void templated2(Arg<int, notatuple>, int);

template <typename T, int a = 3>
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void
templated3(Arg<T, notatuple, a, ns1::hasDefaultArg<>, int, int>, T end) {
}

template void templated3(Arg<int, notatuple, 3, ns1::hasDefaultArg<>, int, int>, int);

// CHECK: Forward declarations of kernel and its argument types:
// CHECK-NEXT: namespace ns {
// CHECK-NEXT: struct notatuple;
// CHECK-NEXT: }
// CHECK-NEXT: namespace ns {
// CHECK-NEXT: template <typename T, typename, int a, typename, typename ...TS> struct Arg;
// CHECK-NEXT: }

// CHECK: void ns::simple(ns::Arg<char, int, 12, ns::notatuple>);
// CHECK-NEXT: static constexpr auto __sycl_shim1() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))simple;
// CHECK-NEXT: }

// CHECK: Forward declarations of kernel and its argument types:
// CHECK: namespace ns {
// CHECK: namespace ns1 {
// CHECK-NEXT: template <typename A> class hasDefaultArg;
// CHECK-NEXT: }

// CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple>);
// CHECK-NEXT: static constexpr auto __sycl_shim2() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, 12, struct ns::notatuple>))simple1;
// CHECK-NEXT: }

// CHECK: template <typename T> void templated(ns::Arg<T, float, 3, ns::notatuple>, T end);
// CHECK-NEXT: static constexpr auto __sycl_shim3() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, float, 3, struct ns::notatuple>, int))templated<int>;
// CHECK-NEXT: }

// CHECK: template <typename T> void templated2(ns::Arg<T, ns::notatuple, 12, ns::notatuple>, T end);
// CHECK-NEXT: static constexpr auto __sycl_shim4() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, struct ns::notatuple, 12, struct ns::notatuple>, int))templated2<int>;
// CHECK-NEXT: }

// CHECK: template <typename T, int a> void templated3(ns::Arg<T, ns::notatuple, a, ns::ns1::hasDefaultArg<ns::notatuple>, int, int>, T end);
// CHECK-NEXT: static constexpr auto __sycl_shim5() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, struct ns::notatuple, 3, class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, int>, int))templated3<int, 3>;
// CHECK-NEXT: }
Loading