diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index fa773e7c7bb3e..a22e8405372de 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -13038,6 +13038,8 @@ def err_free_function_first_occurrence_missing_attr: Error< "the first occurrence of SYCL kernel free function should be declared with 'sycl-nd-range-kernel' or 'sycl-single-task-kernel' compile time properties">; def err_free_function_class_method : Error< "%select{static |}0class method cannot be used to define a SYCL kernel free function kernel">; +def err_sycl_kernel_virtual_arg : Error< + "argument type '%0' virtually inherited from base class `%1` is not supported as a SYCL kernel argument">; // SYCL kernel entry point diagnostics diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index edb3aab35e2d5..c86e49a1be33b 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -1258,12 +1258,19 @@ constructFreeFunctionKernelName(const FunctionDecl *FreeFunc, MC.mangleName(FreeFunc, Out); std::string MangledName(Out.str()); size_t StartNums = MangledName.find_first_of("0123456789"); - size_t EndNums = MangledName.find_first_not_of("0123456789", StartNums); - size_t NameLength = - std::stoi(MangledName.substr(StartNums, EndNums - StartNums)); - size_t NewNameLength = 14 /*length of __sycl_kernel_*/ + NameLength; - NewName = MangledName.substr(0, StartNums) + std::to_string(NewNameLength) + - "__sycl_kernel_" + MangledName.substr(EndNums); + if (StartNums == std::string::npos) { + // Microsoft mangling name has template like ?FunctionName@@YAXH@Z + NewName = + MangledName.substr(0, 1) + "sycl_kernel_" + MangledName.substr(1); + } else { + size_t EndNums = MangledName.find_first_not_of("0123456789", StartNums); + size_t NameLength = + std::stoi(MangledName.substr(StartNums, EndNums - StartNums)); + size_t NewNameLength = 14 /*length of __sycl_kernel_*/ + NameLength; + NewName = MangledName.substr(0, StartNums) + + std::to_string(NewNameLength) + "__sycl_kernel_" + + MangledName.substr(EndNums); + } } StableName = NewName; return {NewName, StableName}; @@ -1932,6 +1939,10 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { // class is entered. int StructBaseDepth = -1; + // Used to track FunctionDecl location in case if it is not available directly + // from method + SourceLocation FreeFunctionLoc; + // Check whether the object should be disallowed from being copied to kernel. // Return true if not copyable, false if copyable. bool checkNotCopyableToKernel(const FieldDecl *FD, QualType FieldTy) { @@ -2045,8 +2056,13 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { } public: - SyclKernelFieldChecker(SemaSYCL &S) - : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {} + /// Constructor for the SyclKernelFieldChecker + /// \param S The SemaSYCL reference used for diagnostics and context. + /// \param FFLoc Free function location, used to report diagnostics + explicit SyclKernelFieldChecker(SemaSYCL &S, + SourceLocation FFLoc = SourceLocation()) + : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()), + FreeFunctionLoc(FFLoc) {} static constexpr const bool VisitNthArrayElement = false; bool isValid() { return !IsInvalid; } @@ -2206,10 +2222,20 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { return true; } - bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &, + bool leaveStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &B, QualType) final { --StructBaseDepth; - return true; + // FreeFunctionLoc.isInvalid() shows if checker object was created for a + // free function. If that is the case, point to the free function + // declaration. + if (B.isVirtual()) { + Diag.Report(FreeFunctionLoc.isInvalid() ? RD->getLocation() + : FreeFunctionLoc, + diag::err_sycl_kernel_virtual_arg) + << RD->getNameAsString() << B.getType().getAsString(); + IsInvalid = true; + } + return isValid(); } }; @@ -5900,7 +5926,7 @@ void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) { FreeFunctionDeclarations.erase(FD->getCanonicalDecl()); SyclKernelDecompMarker DecompMarker(*this); - SyclKernelFieldChecker FieldChecker(*this); + SyclKernelFieldChecker FieldChecker(*this, FD->getLocation()); SyclKernelUnionChecker UnionChecker(*this); KernelObjVisitor Visitor{*this}; diff --git a/clang/test/SemaSYCL/free_function_negative.cpp b/clang/test/SemaSYCL/free_function_negative.cpp index d7b03ddec04ed..7b1f3f6e36f92 100644 --- a/clang/test/SemaSYCL/free_function_negative.cpp +++ b/clang/test/SemaSYCL/free_function_negative.cpp @@ -1,4 +1,4 @@ -// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -verify=expected -fsycl-int-header=%t.h %s +// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -verify=expected %s #include "sycl.hpp" @@ -118,3 +118,78 @@ static void StaticsingleTaskKernelMethod(int Value) { } }; + +class Base {}; +class Derived : virtual public Base {}; + +// expected-error@+2 {{argument type 'Derived' virtually inherited from base class `Base` is not supported as a SYCL kernel argument}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg(Derived Value) { +} + +// expected-error@+2 1 {{argument type 'Derived' virtually inherited from base class `Base` is not supported as a SYCL kernel argument}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg1(int a, Derived Value, float b, Derived Value1) { +} + +class Derived1 : public Derived { +}; + +// expected-error@+2 {{argument type 'Derived' virtually inherited from base class `Base` is not supported as a SYCL kernel argument}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg2(Derived1 Value) { +} + +class Base1 {}; +class Derived2 : public Base1, public virtual Base { +}; + +// expected-error@+2 {{argument type 'Derived2' virtually inherited from base class `Base` is not supported as a SYCL kernel argumen}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg3(Derived2 Value) { +} + +template +class Derived3 : virtual T { +}; + +// expected-error@+2 {{argument type 'Derived3' virtually inherited from base class `class Base` is not supported as a SYCL kernel argument}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg4(Derived3 Value) { +} + +// expected-error@+3 {{argument type 'Derived3' virtually inherited from base class `class Derived2` is not supported as a SYCL kernel argument}} +// expected-error@+2 {{argument type 'Derived2' virtually inherited from base class `Base` is not supported as a SYCL kernel argument}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg5(Derived3 Value) { +} + +template +class Derived4 : T { +}; + +// expected-error@+2 {{argument type 'Derived' virtually inherited from base class `Base` is not supported as a SYCL kernel argument}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg6(Derived4 Value) { +} + +// expected-error@+2 {{argument type 'Derived2' virtually inherited from base class `Base` is not supported as a SYCL kernel argument}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg7(Derived4 Value) { +} + + +template +class Derived5 : T, virtual Base { +}; + +// expected-error@+2 {{argument type 'Derived5' virtually inherited from base class `Base` is not supported as a SYCL kernel argument}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg7(Derived5 Value) { +} + +// expected-error@+3 {{argument type 'Derived5' virtually inherited from base class `Base` is not supported as a SYCL kernel argument}} +// expected-error@+2 {{argument type 'Derived' virtually inherited from base class `Base` is not supported as a SYCL kernel argument}} +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]] +void VirtualInheritArg7(Derived5 Value) { +} diff --git a/sycl/test-e2e/FreeFunctionKernels/virtual_methods.cpp b/sycl/test-e2e/FreeFunctionKernels/virtual_methods.cpp new file mode 100644 index 0000000000000..a887acaa2851a --- /dev/null +++ b/sycl/test-e2e/FreeFunctionKernels/virtual_methods.cpp @@ -0,0 +1,101 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +/* + * Test to check class/struct type with virtual methods as SYCL free function + * kernel arguments. + */ + +#include +#include +#include +#include +#include + +namespace syclext = sycl::ext::oneapi; +namespace syclexp = sycl::ext::oneapi::experimental; + +static constexpr size_t NUM = 1024; +static constexpr size_t WGSIZE = 16; +static constexpr auto FFTestMark = "Free function Kernel Test:"; +static constexpr float offset = 1.1f; + +class Base { +public: + virtual void virtual_method(float start) = 0; + virtual ~Base() = default; +}; + +class TestClass : public Base { + float data = 0.0f; + +public: + void virtual_method(float start) override {} + + float calculate(float start, size_t id) { + return start + static_cast(id) + data; + } + + void setData(float value) { data = value; } +}; + +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>)) +void func_range(TestClass *acc, float *ptr) { + size_t id = syclext::this_work_item::get_nd_item<1>().get_global_linear_id(); + ptr[id] = acc->calculate(3.14f, id); +} + +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::single_task_kernel)) +void func_single(TestClass *acc, float *ptr) { + size_t id = syclext::this_work_item::get_nd_item<1>().get_global_linear_id(); + ptr[id] = acc->calculate(3.14f, id); +} + +int check_result(float *ptr) { + for (size_t i = 0; i < NUM; ++i) { + const float expected = 3.14f + static_cast(i) + offset; + if (ptr[i] != expected) + return 1; + } + return 0; +} + +int call_kernel_code(sycl::queue &q, sycl::kernel &kernel) { + float *ptr = sycl::malloc_shared(NUM, q); + TestClass *obj = sycl::malloc_shared(1, q); + obj->setData(offset); + + q.submit([&](sycl::handler &cgh) { + cgh.set_args(obj, ptr); + sycl::nd_range ndr{{NUM}, {WGSIZE}}; + cgh.parallel_for(ndr, kernel); + }).wait(); + int ret = check_result(ptr); + sycl::free(ptr, q); + sycl::free(obj, q); + return ret; +} + +template +int test_arg_with_virtual_method(sycl::queue &q, sycl::context &ctxt, + std::string_view name) { + auto exe_bndl = + syclexp::get_kernel_bundle(ctxt); + sycl::kernel k_func = exe_bndl.template ext_oneapi_get_kernel(); + int ret = call_kernel_code(q, k_func); + if (ret != 0) + std::cerr << FFTestMark << name << " failed\n"; + return ret; +} + +int main() { + sycl::queue q; + sycl::context ctxt = q.get_context(); + sycl::device dev = q.get_device(); + + int ret = + test_arg_with_virtual_method(q, ctxt, "virtual_method_range"); + ret |= test_arg_with_virtual_method(q, ctxt, + "virtual_method_single"); + return ret; +}