Skip to content

Commit 5176ed8

Browse files
[SYCL] Free function kernels bugfix (#19535)
When a free function kernel is a function template whose arguments are themselves templates, we may not print those arguments properly into the integration header. Specifically, if an argument of a templated free function kernel has an `enum` template argument of its own, then the namespace qualifiers for it will be missing. This patch fixes that.
1 parent 94c70cf commit 5176ed8

File tree

6 files changed

+159
-56
lines changed

6 files changed

+159
-56
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "clang/AST/SYCLKernelInfo.h"
1919
#include "clang/AST/StmtSYCL.h"
2020
#include "clang/AST/TemplateArgumentVisitor.h"
21+
#include "clang/AST/Type.h"
2122
#include "clang/AST/TypeOrdering.h"
2223
#include "clang/AST/TypeVisitor.h"
2324
#include "clang/Analysis/CallGraph.h"
@@ -6738,23 +6739,75 @@ class FreeFunctionPrinter {
67386739
/// returned string Example:
67396740
/// \code
67406741
/// template <typename T1, typename T2>
6741-
/// void foo(T1 a, T2 b);
6742+
/// void foo(T1 a, int b, T2 c);
67426743
/// \endcode
6743-
/// returns string "T1 a, T2 b"
6744+
/// returns string "T1, int, T2"
67446745
std::string
67456746
getTemplatedParamList(const llvm::ArrayRef<clang::ParmVarDecl *> Parameters,
67466747
PrintingPolicy Policy) {
67476748
bool FirstParam = true;
67486749
llvm::SmallString<128> ParamList;
67496750
llvm::raw_svector_ostream ParmListOstream{ParamList};
67506751
Policy.SuppressTagKeyword = true;
6752+
67516753
for (ParmVarDecl *Param : Parameters) {
67526754
if (FirstParam)
67536755
FirstParam = false;
67546756
else
67556757
ParmListOstream << ", ";
6756-
ParmListOstream << Param->getType().getAsString(Policy);
6757-
ParmListOstream << " " << Param->getNameAsString();
6758+
6759+
// There are cases when we can't directly use either the original
6760+
// argument type or its canonical version. An example would be:
6761+
// template<typename T>
6762+
// void kernel(sycl::accessor<T, 1>);
6763+
// template void kernel(sycl::accessor<int, 1>);
6764+
// Accessor has multiple non-type template arguments with default values
6765+
// and non-qualified type will not include necessary namespaces for all
6766+
// of them. Qualified type will have that information, but all references
6767+
// to T will be replaced with something like type-argument-0.
6768+
// What we do instead is we iterate template arguments of both versions
6769+
// of a type in sync and take elements from one or another to get the best
6770+
// of both: proper references to template arguments of a kernel itself and
6771+
// fully-qualified names for enumerations.
6772+
//
6773+
// Moral of the story: drop integration header ASAP (but that is blocked
6774+
// by support for 3rd-party host compilers, which is important).
6775+
QualType T = Param->getType();
6776+
QualType CT = T.getCanonicalType();
6777+
6778+
auto *ET = dyn_cast<ElaboratedType>(T.getTypePtr());
6779+
if (!ET) {
6780+
ParmListOstream << T.getAsString(Policy);
6781+
continue;
6782+
}
6783+
6784+
auto *TST =
6785+
dyn_cast<TemplateSpecializationType>(ET->getNamedType().getTypePtr());
6786+
auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr());
6787+
if (!TST || !CTST) {
6788+
ParmListOstream << T.getAsString(Policy);
6789+
continue;
6790+
}
6791+
6792+
TemplateName TN = TST->getTemplateName();
6793+
auto SpecArgs = TST->template_arguments();
6794+
auto DeclArgs = CTST->template_arguments();
6795+
6796+
TN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
6797+
ParmListOstream << "<";
6798+
6799+
for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
6800+
SE = SpecArgs.size();
6801+
I < E; ++I) {
6802+
if (I != 0)
6803+
ParmListOstream << ", ";
6804+
if (I < SE) // A specialized argument exists, use it
6805+
SpecArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
6806+
else // Print a canonical form of a default argument
6807+
DeclArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
6808+
}
6809+
6810+
ParmListOstream << ">";
67586811
}
67596812
return ParamList.str().str();
67606813
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
2+
// RUN: FileCheck -input-file=%t.h %s
3+
//
4+
// The purpose of this test is to ensure that forward declarations of free
5+
// function kernels are emitted properly.
6+
// However, this test checks a specific scenario:
7+
// - free function kernel is a function template
8+
// - its argument is templated and has a non-type template parameter (with a default
9+
// value) that is an enumeration defined within a namespace
10+
11+
namespace ns {
12+
13+
enum class enum_A { A, B, C };
14+
15+
template<typename T, enum_A V = enum_A::B>
16+
class feature_A {};
17+
18+
namespace nested {
19+
enum class enum_B { A, B, C };
20+
21+
template<typename T, int V, enum_B V2 = enum_B::A, enum_A V3 = enum_A::C>
22+
struct feature_B {};
23+
}
24+
25+
inline namespace nested_inline {
26+
namespace nested2 {
27+
enum class enum_C { A, B, C };
28+
29+
template<int V = 42, enum_C V2 = enum_C::B>
30+
struct feature_C {};
31+
}
32+
} // namespace nested_inline
33+
} // namespace ns
34+
35+
template<typename T>
36+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
37+
void templated_on_A(ns::feature_A<T> Arg) {}
38+
template void templated_on_A(ns::feature_A<int>);
39+
40+
// CHECK: template <typename T> void templated_on_A(ns::feature_A<T, ns::enum_A::B>);
41+
42+
template<typename T, int V = 42>
43+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
44+
void templated_on_B(ns::nested::feature_B<T, V> Arg) {}
45+
template void templated_on_B(ns::nested::feature_B<int, 12>);
46+
47+
// CHECK: template <typename T, int V> void templated_on_B(ns::nested::feature_B<T, V, ns::nested::enum_B::A, ns::enum_A::C>);
48+
49+
template<int V>
50+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
51+
void templated_on_C(ns::nested2::feature_C<V> Arg) {}
52+
template void templated_on_C(ns::nested2::feature_C<42>);
53+
54+
// CHECK: template <int V> void templated_on_C(ns::nested2::feature_C<V, ns::nested2::enum_C::B>);

clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ namespace Testing::Tests {
295295
// CHECK-NEXT: } // namespace _V1
296296
// CHECK-NEXT: } // namespace sycl
297297

298-
// CHECK: template <typename T> void templated(ns::Arg<T, float, 3, ns::notatuple> , T end);
298+
// CHECK: template <typename T> void templated(ns::Arg<T, float, 3, ns::notatuple, <>>, T);
299299
// CHECK-NEXT: static constexpr auto __sycl_shim3() {
300300
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, float, 3, struct ns::notatuple>, int))templated<int>;
301301
// CHECK-NEXT: }
@@ -314,7 +314,7 @@ namespace Testing::Tests {
314314
// CHECK-NEXT: } // namespace _V1
315315
// CHECK-NEXT: } // namespace sycl
316316

317-
// CHECK: template <typename T> void templated2(ns::Arg<T, ns::notatuple, 12, ns::notatuple> , T end);
317+
// CHECK: template <typename T> void templated2(ns::Arg<T, ns::notatuple, 12, ns::notatuple, <>>, T);
318318
// CHECK-NEXT: static constexpr auto __sycl_shim4() {
319319
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, struct ns::notatuple, 12, struct ns::notatuple>, int))templated2<int>;
320320
// CHECK-NEXT: }
@@ -333,7 +333,7 @@ namespace Testing::Tests {
333333
// CHECK-NEXT: } // namespace _V1
334334
// CHECK-NEXT: } // namespace sycl
335335

336-
// CHECK: template <typename T, int a> void templated3(ns::Arg<T, ns::notatuple, a, ns::ns1::hasDefaultArg<ns::notatuple>, int, int> , T end);
336+
// CHECK: template <typename T, int a> void templated3(ns::Arg<T, ns::notatuple, a, ns::ns1::hasDefaultArg<ns::notatuple>, int, int>, T);
337337
// CHECK-NEXT: static constexpr auto __sycl_shim5() {
338338
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, struct ns::notatuple, 3, class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, int>, int))templated3<int, 3>;
339339
// CHECK-NEXT: }
@@ -352,7 +352,7 @@ namespace Testing::Tests {
352352
// CHECK-NEXT: } // namespace _V1
353353
// CHECK-NEXT: } // namespace sycl
354354

355-
// CHECK: template <typename T, int a> void templated3(ns::Arg<T, ns::notatuple, a, ns::ns1::hasDefaultArg<ns::notatuple>, int, int> , T end);
355+
// CHECK: template <typename T, int a> void templated3(ns::Arg<T, ns::notatuple, a, ns::ns1::hasDefaultArg<ns::notatuple>, int, int>, T);
356356
// CHECK-NEXT: static constexpr auto __sycl_shim6() {
357357
// CHECK-NEXT: return (void (*)(struct ns::Arg<float, struct ns::notatuple, 3, class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, int>, float))templated3<float, 3>;
358358
// CHECK-NEXT: }
@@ -400,7 +400,7 @@ namespace Testing::Tests {
400400
// CHECK-NEXT: } // namespace sycl
401401

402402
// CHECK: namespace TestNamespace {
403-
// CHECK-NEXT: template <typename T> void templated(ns::Arg<T, float, 3, ns::notatuple> , T end);
403+
// CHECK-NEXT: template <typename T> void templated(ns::Arg<T, float, 3, ns::notatuple, <>>, T);
404404
// CHECK-NEXT: } // namespace TestNamespace
405405

406406
// CHECK: static constexpr auto __sycl_shim8() {
@@ -434,7 +434,7 @@ namespace Testing::Tests {
434434

435435
// CHECK: namespace TestNamespace {
436436
// CHECK-NEXT: inline namespace _V1 {
437-
// CHECK-NEXT: template <typename T, int a> void templated1(ns::Arg<T, float, a, ns::notatuple> , T end);
437+
// CHECK-NEXT: template <typename T, int a> void templated1(ns::Arg<T, float, a, ns::notatuple, <>>, T);
438438
// CHECK-NEXT: } // inline namespace _V1
439439
// CHECK-NEXT: } // namespace TestNamespace
440440
// CHECK: static constexpr auto __sycl_shim9() {
@@ -468,7 +468,7 @@ namespace Testing::Tests {
468468

469469
// CHECK: namespace TestNamespace {
470470
// CHECK-NEXT: inline namespace _V2 {
471-
// CHECK-NEXT: template <typename T, int a> void templated1(ns::Arg<T, T, a, ns::notatuple> , T end);
471+
// CHECK-NEXT: template <typename T, int a> void templated1(ns::Arg<T, T, a, ns::notatuple, <>>, T);
472472
// CHECK-NEXT: } // inline namespace _V2
473473
// CHECK-NEXT: } // namespace TestNamespace
474474
// CHECK: static constexpr auto __sycl_shim10() {
@@ -501,7 +501,7 @@ namespace Testing::Tests {
501501
// CHECK-NEXT: }
502502

503503
// CHECK: namespace {
504-
// CHECK-NEXT: template <typename T> void templated(T start, T end);
504+
// CHECK-NEXT: template <typename T> void templated(T, T);
505505
// CHECK-NEXT: } // namespace
506506
// CHECK: static constexpr auto __sycl_shim11() {
507507
// CHECK-NEXT: return (void (*)(float, float))templated<float>;
@@ -533,7 +533,7 @@ namespace Testing::Tests {
533533
// CHECK-NEXT: }
534534

535535
// CHECK: struct TestStruct;
536-
// CHECK: template <typename T> void templated(ns::Arg<T, float, 3, ns::notatuple> , T end);
536+
// CHECK: template <typename T> void templated(ns::Arg<T, float, 3, ns::notatuple, <>>, T);
537537
// CHECK-NEXT: static constexpr auto __sycl_shim12() {
538538
// CHECK-NEXT: return (void (*)(struct ns::Arg<struct TestStruct, float, 3, struct ns::notatuple>, struct TestStruct))templated<struct TestStruct>;
539539
// CHECK-NEXT:}
@@ -565,7 +565,7 @@ namespace Testing::Tests {
565565

566566
// CHECK: class BaseClass;
567567
// CHECK: namespace {
568-
// CHECK-NEXT: template <typename T> void templated(T start, T end);
568+
// CHECK-NEXT: template <typename T> void templated(T, T);
569569
// CHECK-NEXT: } // namespace
570570
// CHECK: static constexpr auto __sycl_shim13() {
571571
// CHECK-NEXT: return (void (*)(class BaseClass, class BaseClass))templated<class BaseClass>;
@@ -598,7 +598,7 @@ namespace Testing::Tests {
598598

599599
// CHECK: class ChildOne;
600600
// CHECK: namespace {
601-
// CHECK-NEXT: template <typename T> void templated(T start, T end);
601+
// CHECK-NEXT: template <typename T> void templated(T, T);
602602
// CHECK-NEXT: } // namespace
603603
// CHECK: static constexpr auto __sycl_shim14() {
604604
// CHECK-NEXT: return (void (*)(class ChildOne, class ChildOne))templated<class ChildOne>;
@@ -631,7 +631,7 @@ namespace Testing::Tests {
631631

632632
// CHECK: class ChildTwo;
633633
// CHECK: namespace {
634-
// CHECK-NEXT: template <typename T> void templated(T start, T end);
634+
// CHECK-NEXT: template <typename T> void templated(T, T);
635635
// CHECK-NEXT: } // namespace
636636
// CHECK: static constexpr auto __sycl_shim15() {
637637
// CHECK-NEXT: return (void (*)(class ChildTwo, class ChildTwo))templated<class ChildTwo>;
@@ -664,7 +664,7 @@ namespace Testing::Tests {
664664

665665
// CHECK: class ChildThree;
666666
// CHECK: namespace {
667-
// CHECK-NEXT: template <typename T> void templated(T start, T end);
667+
// CHECK-NEXT: template <typename T> void templated(T, T);
668668
// CHECK-NEXT: } // namespace
669669
// CHECK: static constexpr auto __sycl_shim16() {
670670
// CHECK-NEXT: return (void (*)(class ChildThree, class ChildThree))templated<class ChildThree>;
@@ -699,7 +699,7 @@ namespace Testing::Tests {
699699
// CHECK-NEXT: template <int dim> struct id;
700700
// CHECK-NEXT: }}
701701
// CHECK: namespace {
702-
// CHECK-NEXT: template <typename T> void templated(T start, T end);
702+
// CHECK-NEXT: template <typename T> void templated(T, T);
703703
// CHECK-NEXT: } // namespace
704704
// CHECK: static constexpr auto __sycl_shim17() {
705705
// CHECK-NEXT: return (void (*)(struct sycl::id<2>, struct sycl::id<2>))templated<struct sycl::id<2>>;
@@ -734,7 +734,7 @@ namespace Testing::Tests {
734734
// CHECK-NEXT: template <int dim> struct range;
735735
// CHECK-NEXT: }}
736736
// CHECK: namespace {
737-
// CHECK-NEXT: template <typename T> void templated(T start, T end);
737+
// CHECK-NEXT: template <typename T> void templated(T, T);
738738
// CHECK-NEXT: } // namespace
739739
// CHECK: static constexpr auto __sycl_shim18() {
740740
// CHECK-NEXT: return (void (*)(struct sycl::range<3>, struct sycl::range<3>))templated<struct sycl::range<3>>;
@@ -766,7 +766,7 @@ namespace Testing::Tests {
766766
// CHECK-NEXT: }
767767

768768
// CHECK: namespace {
769-
// CHECK-NEXT: template <typename T> void templated(T start, T end);
769+
// CHECK-NEXT: template <typename T> void templated(T, T);
770770
// CHECK-NEXT: } // namespace
771771
// CHECK: static constexpr auto __sycl_shim19() {
772772
// CHECK-NEXT: return (void (*)(int *, int *))templated<int *>;
@@ -798,7 +798,7 @@ namespace Testing::Tests {
798798
// CHECK-NEXT: }
799799

800800
// CHECK: namespace {
801-
// CHECK-NEXT: template <typename T> void templated(T start, T end);
801+
// CHECK-NEXT: template <typename T> void templated(T, T);
802802
// CHECK-NEXT: } // namespace
803803
// CHECK: static constexpr auto __sycl_shim20() {
804804
// CHECK-NEXT: return (void (*)(struct sycl::X<class ChildTwo>, struct sycl::X<class ChildTwo>))templated<struct sycl::X<class ChildTwo>>;
@@ -835,7 +835,7 @@ namespace Testing::Tests {
835835
// CHECK-NEXT: }}}
836836
// CHECK: namespace TestNamespace {
837837
// CHECK-NEXT: inline namespace _V1 {
838-
// CHECK-NEXT: template <typename T, int a> void templated1(ns::Arg<T, float, a, ns::notatuple> , T end);
838+
// CHECK-NEXT: template <typename T, int a> void templated1(ns::Arg<T, float, a, ns::notatuple, <>>, T);
839839
// CHECK-NEXT: } // inline namespace _V1
840840
// CHECK-NEXT: } // namespace TestNamespace
841841
// CHECK: static constexpr auto __sycl_shim21() {
@@ -867,7 +867,7 @@ namespace Testing::Tests {
867867
// CHECK-NEXT: };
868868
// CHECK-NEXT: }
869869

870-
// CHECK: template <typename ... Args> void variadic_templated(Args... args);
870+
// CHECK: template <typename ... Args> void variadic_templated(Args...);
871871
// CHECK-NEXT: static constexpr auto __sycl_shim22() {
872872
// CHECK-NEXT: return (void (*)(int, float, char))variadic_templated<int, float, char>;
873873
// CHECK-NEXT: }
@@ -897,7 +897,7 @@ namespace Testing::Tests {
897897
// CHECK-NEXT: };
898898
// CHECK-NEXT: }
899899

900-
// CHECK: template <typename ... Args> void variadic_templated(Args... args);
900+
// CHECK: template <typename ... Args> void variadic_templated(Args...);
901901
// CHECK-NEXT: static constexpr auto __sycl_shim23() {
902902
// CHECK-NEXT: return (void (*)(int, float, char, int))variadic_templated<int, float, char, int>;
903903
// CHECK-NEXT: }
@@ -927,7 +927,7 @@ namespace Testing::Tests {
927927
// CHECK-NEXT: };
928928
// CHECK-NEXT: }
929929

930-
// CHECK: template <typename ... Args> void variadic_templated(Args... args);
930+
// CHECK: template <typename ... Args> void variadic_templated(Args...);
931931
// CHECK-NEXT: static constexpr auto __sycl_shim24() {
932932
// CHECK-NEXT: return (void (*)(float, float))variadic_templated<float, float>;
933933
// CHECK-NEXT: }
@@ -957,7 +957,7 @@ namespace Testing::Tests {
957957
// CHECK-NEXT: };
958958
// CHECK-NEXT: }
959959

960-
// CHECK: template <typename T, typename ... Args> void variadic_templated1(T b, Args... args);
960+
// CHECK: template <typename T, typename ... Args> void variadic_templated1(T, Args...);
961961
// CHECK-NEXT: static constexpr auto __sycl_shim25() {
962962
// CHECK-NEXT: return (void (*)(float, char, char))variadic_templated1<float, char, char>;
963963
// CHECK-NEXT: }
@@ -987,7 +987,7 @@ namespace Testing::Tests {
987987
// CHECK-NEXT: };
988988
// CHECK-NEXT: }
989989

990-
// CHECK: template <typename T, typename ... Args> void variadic_templated1(T b, Args... args);
990+
// CHECK: template <typename T, typename ... Args> void variadic_templated1(T, Args...);
991991
// CHECK-NEXT: static constexpr auto __sycl_shim26() {
992992
// CHECK-NEXT: return (void (*)(int, float, char))variadic_templated1<int, float, char>;
993993
// CHECK-NEXT: }
@@ -1019,7 +1019,7 @@ namespace Testing::Tests {
10191019

10201020
// CHECK: namespace Testing {
10211021
// CHECK-NEXT: namespace Tests {
1022-
// CHECK-NEXT: template <typename T, typename ... Args> void variadic_templated(T b, Args... args);
1022+
// CHECK-NEXT: template <typename T, typename ... Args> void variadic_templated(T, Args...);
10231023
// CHECK-NEXT: } // namespace Tests
10241024
// CHECK-NEXT: } // namespace Testing
10251025
// CHECK: static constexpr auto __sycl_shim27() {
@@ -1053,7 +1053,7 @@ namespace Testing::Tests {
10531053

10541054
// CHECK: namespace Testing {
10551055
// CHECK-NEXT: namespace Tests {
1056-
// CHECK-NEXT: template <typename T, typename ... Args> void variadic_templated(T b, Args... args);
1056+
// CHECK-NEXT: template <typename T, typename ... Args> void variadic_templated(T, Args...);
10571057
// CHECK-NEXT: } // namespace Tests
10581058
// CHECK-NEXT: } // namespace Testing
10591059
// CHECK: static constexpr auto __sycl_shim28() {

0 commit comments

Comments
 (0)