diff --git a/clang/include/clang/AST/PrettyPrinter.h b/clang/include/clang/AST/PrettyPrinter.h index a50216615c4a9..beadfea6e3ae1 100644 --- a/clang/include/clang/AST/PrettyPrinter.h +++ b/clang/include/clang/AST/PrettyPrinter.h @@ -68,17 +68,18 @@ struct PrintingPolicy { SuppressStrongLifetime(false), SuppressLifetimeQualifiers(false), SuppressTypedefs(false), SuppressFinalSpecifier(false), SuppressTemplateArgsInCXXConstructors(false), - SuppressDefaultTemplateArgs(true), Bool(LO.Bool), - Nullptr(LO.CPlusPlus11 || LO.C23), NullptrTypeInNamespace(LO.CPlusPlus), - Restrict(LO.C99), Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11), + SuppressDefaultTemplateArgs(true), EnforceDefaultTemplateArgs(false), + Bool(LO.Bool), Nullptr(LO.CPlusPlus11 || LO.C23), + NullptrTypeInNamespace(LO.CPlusPlus), Restrict(LO.C99), + Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11), UseVoidForZeroParams(!LO.CPlusPlus), SplitTemplateClosers(!LO.CPlusPlus11), TerseOutput(false), PolishForDeclaration(false), Half(LO.Half), MSWChar(LO.MicrosoftExt && !LO.WChar), IncludeNewlines(true), MSVCFormatting(false), ConstantsAsWritten(false), SuppressImplicitBase(false), FullyQualifiedName(false), - SuppressDefinition(false), SuppressDefaultTemplateArguments(false), - PrintCanonicalTypes(false), + EnforceScopeForElaboratedTypes(false), SuppressDefinition(false), + SuppressDefaultTemplateArguments(false), PrintCanonicalTypes(false), SkipCanonicalizationOfTemplateTypeParms(false), PrintInjectedClassNameWithArguments(true), UsePreferredNames(true), AlwaysIncludeTypeForTemplateArgument(false), @@ -241,6 +242,11 @@ struct PrintingPolicy { LLVM_PREFERRED_TYPE(bool) unsigned SuppressDefaultTemplateArgs : 1; + /// When true, print template arguments that match the default argument for + /// the parameter, even if they're not specified in the source. + LLVM_PREFERRED_TYPE(bool) + unsigned EnforceDefaultTemplateArgs : 1; + /// Whether we can use 'bool' rather than '_Bool' (even if the language /// doesn't actually have 'bool', because, e.g., it is defined as a macro). LLVM_PREFERRED_TYPE(bool) @@ -339,6 +345,10 @@ struct PrintingPolicy { LLVM_PREFERRED_TYPE(bool) unsigned FullyQualifiedName : 1; + /// Enforce fully qualified name printing for elaborated types. + LLVM_PREFERRED_TYPE(bool) + unsigned EnforceScopeForElaboratedTypes : 1; + /// When true does not print definition of a type. E.g. /// \code /// template class C0 : public C1 {...} diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp index 636ddaddf8769..49eb096cf369f 100644 --- a/clang/lib/AST/TypePrinter.cpp +++ b/clang/lib/AST/TypePrinter.cpp @@ -101,7 +101,7 @@ class ElaboratedTypePolicyRAII { SuppressTagKeyword = Policy.SuppressTagKeyword; SuppressScope = Policy.SuppressScope; Policy.SuppressTagKeyword = true; - Policy.SuppressScope = true; + Policy.SuppressScope = !Policy.EnforceScopeForElaboratedTypes; } ~ElaboratedTypePolicyRAII() { @@ -1728,8 +1728,10 @@ void TypePrinter::printElaboratedBefore(const ElaboratedType *T, Policy.SuppressScope = OldSupressScope; return; } - if (Qualifier && !(Policy.SuppressTypedefs && - T->getNamedType()->getTypeClass() == Type::Typedef)) + if (Qualifier && + !(Policy.SuppressTypedefs && + T->getNamedType()->getTypeClass() == Type::Typedef) && + !Policy.EnforceScopeForElaboratedTypes) Qualifier->print(OS, Policy); } @@ -2220,15 +2222,6 @@ static void printArgument(const TemplateArgument &A, const PrintingPolicy &PP, A.print(PP, OS, IncludeType); } -static void printArgument(const TemplateArgumentLoc &A, - const PrintingPolicy &PP, llvm::raw_ostream &OS, - bool IncludeType) { - const TemplateArgument::ArgKind &Kind = A.getArgument().getKind(); - if (Kind == TemplateArgument::ArgKind::Type) - return A.getTypeSourceInfo()->getType().print(OS, PP); - return A.getArgument().print(PP, OS, IncludeType); -} - static bool isSubstitutedTemplateArgument(ASTContext &Ctx, TemplateArgument Arg, TemplateArgument Pattern, ArrayRef Args, @@ -2399,15 +2392,40 @@ template static void printTo(raw_ostream &OS, ArrayRef Args, const PrintingPolicy &Policy, const TemplateParameterList *TPL, bool IsPack, unsigned ParmIndex) { - // Drop trailing template arguments that match default arguments. - if (TPL && Policy.SuppressDefaultTemplateArgs && - !Policy.PrintCanonicalTypes && !Args.empty() && !IsPack && + llvm::SmallVector ArgsToPrint; + for (const TA &A : Args) + ArgsToPrint.push_back(getArgument(A)); + if (TPL && !Policy.PrintCanonicalTypes && !IsPack && Args.size() <= TPL->size()) { - llvm::SmallVector OrigArgs; - for (const TA &A : Args) - OrigArgs.push_back(getArgument(A)); - while (!Args.empty() && getArgument(Args.back()).getIsDefaulted()) - Args = Args.drop_back(); + // Drop trailing template arguments that match default arguments. + if (Policy.SuppressDefaultTemplateArgs) { + while (!ArgsToPrint.empty() && + getArgument(ArgsToPrint.back()).getIsDefaulted()) + ArgsToPrint.pop_back(); + } else if (Policy.EnforceDefaultTemplateArgs) { + for (unsigned I = Args.size(); I < TPL->size(); ++I) { + auto Param = TPL->getParam(I); + if (auto *TTPD = dyn_cast(Param)) { + // If we met a non default-argument past provided list of arguments, + // it is either a pack which must be the last arguments, or provided + // argument list was problematic. Bail out either way. Do the same + // for each kind of template argument. + if (!TTPD->hasDefaultArgument()) + break; + ArgsToPrint.push_back(getArgument(TTPD->getDefaultArgument())); + } else if (auto *TTPD = dyn_cast(Param)) { + if (!TTPD->hasDefaultArgument()) + break; + ArgsToPrint.push_back(getArgument(TTPD->getDefaultArgument())); + } else if (auto *NTTPD = dyn_cast(Param)) { + if (!NTTPD->hasDefaultArgument()) + break; + ArgsToPrint.push_back(getArgument(NTTPD->getDefaultArgument())); + } else { + llvm_unreachable("unexpected template parameter"); + } + } + } } const char *Comma = Policy.MSVCFormatting ? "," : ", "; @@ -2416,7 +2434,7 @@ printTo(raw_ostream &OS, ArrayRef Args, const PrintingPolicy &Policy, bool NeedSpace = false; bool FirstArg = true; - for (const auto &Arg : Args) { + for (const auto &Arg : ArgsToPrint) { // Print the argument into a string. SmallString<128> Buf; llvm::raw_svector_ostream ArgOS(Buf); diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index e45b038273d77..9b1e8d6c43fdf 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -6483,16 +6483,46 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { O << "extern \"C\" "; std::string ParmList; bool FirstParam = true; + Policy.SuppressDefaultTemplateArgs = false; for (ParmVarDecl *Param : K.SyclKernel->parameters()) { if (FirstParam) FirstParam = false; else ParmList += ", "; - ParmList += Param->getType().getCanonicalType().getAsString(); + ParmList += Param->getType().getCanonicalType().getAsString(Policy); } FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate(); Policy.SuppressDefinition = true; Policy.PolishForDeclaration = true; + Policy.FullyQualifiedName = true; + Policy.EnforceScopeForElaboratedTypes = true; + + // Now we need to print the declaration of the kernel itself. + // Example: + // template struct Arg { + // T val; + // }; + // For the following free function kernel: + // template + // SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( + // (ext::oneapi::experimental::nd_range_kernel<1>)) + // void foo(Arg arg) {} + // Integration header must contain the following declaration: + // template + // void foo(Arg arg); + // SuppressDefaultTemplateArguments is a downstream addition that suppresses + // default template arguments in the function declaration. It should be set + // to true to emit function declaration that won't cause any compilation + // errors when present in the integration header. + // To print Arg in the function declaration and shim functions we + // need to disable default arguments printing suppression via community flag + // SuppressDefaultTemplateArgs, otherwise they will be suppressed even for + // canonical types or if even written in the original source code. + Policy.SuppressDefaultTemplateArguments = true; + // EnforceDefaultTemplateArgs is a downstream addition that forces printing + // template arguments that match default template arguments while printing + // template-ids, even if the source code doesn't reference them. + Policy.EnforceDefaultTemplateArgs = true; if (FTD) { FTD->print(O, Policy); } else { diff --git a/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp b/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp new file mode 100644 index 0000000000000..808f7b93d8112 --- /dev/null +++ b/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp @@ -0,0 +1,100 @@ +// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s +// RUN: FileCheck -input-file=%t.h %s + +// This test checks integration header contents for free functions kernels with +// parameter types that have default template arguments. + +#include "mock_properties.hpp" +#include "sycl.hpp" + +namespace ns { + +struct notatuple { + int a; +}; + +namespace ns1 { +template +class hasDefaultArg { + +}; +} + +template struct Arg { + T val; +}; + +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", + 2)]] void +simple(Arg){ +} + +} + +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", + 2)]] void +simple1(ns::Arg>){ +} + + +template +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void +templated(ns::Arg, T end) { +} + +template void templated(ns::Arg, int); + +using namespace ns; + +template +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void +templated2(Arg, T end) { +} + +template void templated2(Arg, int); + +template +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void +templated3(Arg, int, int>, T end) { +} + +template void templated3(Arg, int, int>, int); + +// CHECK: Forward declarations of kernel and its argument types: +// CHECK-NEXT: namespace ns { +// CHECK-NEXT: struct notatuple; +// CHECK-NEXT: } +// CHECK-NEXT: namespace ns { +// CHECK-NEXT: template struct Arg; +// CHECK-NEXT: } + +// CHECK: void ns::simple(ns::Arg); +// CHECK-NEXT: static constexpr auto __sycl_shim1() { +// CHECK-NEXT: return (void (*)(struct ns::Arg))simple; +// CHECK-NEXT: } + +// CHECK: Forward declarations of kernel and its argument types: +// CHECK: namespace ns { +// CHECK: namespace ns1 { +// CHECK-NEXT: template class hasDefaultArg; +// CHECK-NEXT: } + +// CHECK: void simple1(ns::Arg, int, 12, ns::notatuple>); +// CHECK-NEXT: static constexpr auto __sycl_shim2() { +// CHECK-NEXT: return (void (*)(struct ns::Arg, int, 12, struct ns::notatuple>))simple1; +// CHECK-NEXT: } + +// CHECK: template void templated(ns::Arg, T end); +// CHECK-NEXT: static constexpr auto __sycl_shim3() { +// CHECK-NEXT: return (void (*)(struct ns::Arg, int))templated; +// CHECK-NEXT: } + +// CHECK: template void templated2(ns::Arg, T end); +// CHECK-NEXT: static constexpr auto __sycl_shim4() { +// CHECK-NEXT: return (void (*)(struct ns::Arg, int))templated2; +// CHECK-NEXT: } + +// CHECK: template void templated3(ns::Arg, int, int>, T end); +// CHECK-NEXT: static constexpr auto __sycl_shim5() { +// CHECK-NEXT: return (void (*)(struct ns::Arg, int, int>, int))templated3; +// CHECK-NEXT: }