Skip to content

Commit e57c90c

Browse files
committed
[SYCL] do not emit diagnostics if class with virtual method is used as argument type
1 parent 778314d commit e57c90c

File tree

4 files changed

+122
-92
lines changed

4 files changed

+122
-92
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12970,7 +12970,7 @@ def err_free_function_first_occurrence_missing_attr: Error<
1297012970
def err_free_function_class_method : Error<
1297112971
"%select{static |}0class method cannot be used to define a SYCL kernel free function kernel">;
1297212972
def err_free_function_virtual_arg : Error<
12973-
"argument type %0 of kernel free function can not %select{be virtually derived |have virtual methods}1">;
12973+
"argument type '%0' virtually inherited from base class is not supported as a SYCL kernel free function kernel arguments">;
1297412974

1297512975

1297612976
// SYCL kernel entry point diagnostics

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,6 +1574,26 @@ class KernelObjVisitor {
15741574
else if (ParamTy->isStructureOrClassType()) {
15751575
if (KP_FOR_EACH(handleStructType, Param, ParamTy)) {
15761576
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
1577+
llvm::SmallVector<const CXXRecordDecl *, 8> WorkList;
1578+
llvm::SmallPtrSet<const CXXRecordDecl *, 8> Visited;
1579+
if (RD)
1580+
WorkList.push_back(RD);
1581+
while (!WorkList.empty()) {
1582+
const CXXRecordDecl *Cur = WorkList.pop_back_val();
1583+
if (!Cur || !Visited.insert(Cur).second)
1584+
continue;
1585+
for (const auto &Base : Cur->bases()) {
1586+
if (Base.isVirtual()) {
1587+
SemaSYCLRef.SemaRef.Diag(Param->getLocation(),
1588+
diag::err_free_function_virtual_arg)
1589+
<< Cur->getNameAsString() << Param->getSourceRange();
1590+
return;
1591+
}
1592+
const CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
1593+
if (BaseDecl)
1594+
WorkList.push_back(BaseDecl);
1595+
}
1596+
}
15771597
visitRecord(nullptr, Param, RD, ParamTy, Handlers...);
15781598
}
15791599
} else if (ParamTy->isUnionType())
@@ -5848,44 +5868,6 @@ void SemaSYCL::MarkDevices() {
58485868
checkSYCLAddIRAttributesFunctionAttrConflicts(T.GetSYCLKernel());
58495869
}
58505870
}
5851-
5852-
static bool hasVirtuals(Sema &S, ParmVarDecl *Param) {
5853-
const clang::CXXRecordDecl *ParamType =
5854-
Param->getType()->getAsCXXRecordDecl();
5855-
if (!ParamType || !ParamType->isThisDeclarationADefinition())
5856-
return false;
5857-
5858-
llvm::SmallVector<const clang::CXXRecordDecl *, 8> WorkList;
5859-
llvm::SmallPtrSet<const clang::CXXRecordDecl *, 8> Visited;
5860-
WorkList.push_back(ParamType);
5861-
5862-
while (!WorkList.empty()) {
5863-
const clang::CXXRecordDecl *Cur = WorkList.pop_back_val();
5864-
if (!Visited.insert(Cur).second)
5865-
continue;
5866-
5867-
// Check for virtual bases
5868-
for (const clang::CXXBaseSpecifier &Base : Cur->bases()) {
5869-
if (Base.isVirtual())
5870-
return S.Diag(Param->getLocation(), diag::err_free_function_virtual_arg)
5871-
<< ParamType->getNameAsString() << 0 << Param->getSourceRange();
5872-
const clang::CXXRecordDecl *BaseDecl =
5873-
Base.getType()->getAsCXXRecordDecl();
5874-
if (BaseDecl && BaseDecl->isThisDeclarationADefinition())
5875-
WorkList.push_back(BaseDecl);
5876-
}
5877-
5878-
// Check for virtual member functions
5879-
for (const auto *Method : Cur->methods()) {
5880-
if (Method->isVirtual()) {
5881-
return S.Diag(Param->getLocation(), diag::err_free_function_virtual_arg)
5882-
<< ParamType->getNameAsString() << 1 << Param->getSourceRange();
5883-
}
5884-
}
5885-
}
5886-
return false;
5887-
}
5888-
58895871
static bool CheckFreeFunctionDiagnostics(Sema &S, const FunctionDecl *FD) {
58905872
if (FD->isVariadic()) {
58915873
return S.Diag(FD->getLocation(), diag::err_free_function_variadic_args);
@@ -5908,8 +5890,6 @@ static bool CheckFreeFunctionDiagnostics(Sema &S, const FunctionDecl *FD) {
59085890
diag::err_free_function_with_default_arg)
59095891
<< Param->getSourceRange();
59105892
}
5911-
if (hasVirtuals(S, Param))
5912-
return true;
59135893
}
59145894
return false;
59155895
}

clang/test/SemaSYCL/free_function_negative.cpp

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -119,23 +119,24 @@ static void StaticsingleTaskKernelMethod(int Value) {
119119

120120
};
121121

122+
122123
class Base {};
123124
class Derived : virtual public Base {};
124125

125-
// expected-error@+2 {{argument type Derived of kernel free function can not be virtually derived}}
126+
// expected-error@+2 2 {{argument type 'Derived' virtually inherited from base class is not supported as a SYCL kernel free function kernel arguments}}
126127
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]]
127128
void VirtualInheritArg(Derived Value) {
128129
}
129130

130-
// expected-error@+2 {{argument type Derived of kernel free function can not be virtually derived}}
131+
// expected-error@+2 4 {{argument type 'Derived' virtually inherited from base class is not supported as a SYCL kernel free function kernel arguments}}
131132
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]]
132133
void VirtualInheritArg1(int a, Derived Value, float b, Derived Value1) {
133134
}
134135

135136
class Derived1 : public Derived {
136137
};
137138

138-
// expected-error@+2 {{argument type Derived1 of kernel free function can not be virtually derived}}
139+
// expected-error@+2 2 {{argument type 'Derived' virtually inherited from base class is not supported as a SYCL kernel free function kernel arguments}}
139140
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]]
140141
void VirtualInheritArg2(Derived1 Value) {
141142
}
@@ -144,55 +145,8 @@ class Base1 {};
144145
class Derived2 : public Base1, public virtual Base {
145146
};
146147

147-
// expected-error@+2 {{argument type Derived2 of kernel free function can not be virtually derived}}
148+
// expected-error@+2 2 {{argument type 'Derived2' virtually inherited from base class is not supported as a SYCL kernel free function kernel arguments}}
148149
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]]
149150
void VirtualInheritArg3(Derived2 Value) {
150151
}
151152

152-
struct Derived3 : virtual public Base {
153-
};
154-
155-
// expected-error@+2 {{argument type Derived3 of kernel free function can not be virtually derived}}
156-
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]]
157-
void VirtualInheritArg4(Derived3 Value) {
158-
}
159-
160-
class Class {
161-
public:
162-
virtual void virtualMethod() {}
163-
};
164-
165-
// expected-error@+2 {{argument type Class of kernel free function can not have virtual methods}}
166-
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]]
167-
void VirtualMethodArg1(Class Value) {
168-
}
169-
170-
class Class1: public Class {
171-
public:
172-
void NewMethod() {}
173-
};
174-
175-
// expected-error@+2 {{argument type Class1 of kernel free function can not have virtual methods}}
176-
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 2)]]
177-
void VirtualMethodArg2(Class1 Value) {
178-
}
179-
180-
class Class2: Base1 {
181-
public:
182-
virtual void virtualMethod() {}
183-
};
184-
185-
// expected-error@+2 {{argument type Class2 of kernel free function can not have virtual methods}}
186-
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
187-
void VirtualMethodArg2(Class2 Value) {
188-
}
189-
190-
class Class3 {
191-
public:
192-
virtual ~Class3() {}
193-
};
194-
195-
// expected-error@+2 {{argument type Class3 of kernel free function can not have virtual methods}}
196-
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
197-
void VirtualMethodArg3(Class3 Value) {
198-
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <iostream>
5+
#include <sycl/detail/core.hpp>
6+
#include <sycl/ext/oneapi/free_function_queries.hpp>
7+
#include <sycl/kernel_bundle.hpp>
8+
#include <sycl/usm.hpp>
9+
10+
namespace syclext = sycl::ext::oneapi;
11+
namespace syclexp = sycl::ext::oneapi::experimental;
12+
13+
static constexpr size_t NUM = 1024;
14+
static constexpr size_t WGSIZE = 16;
15+
static constexpr auto FFTestMark = "Free function Kernel Test:";
16+
static constexpr float offset = 1.1f;
17+
18+
class Base {
19+
public:
20+
virtual void virtual_method(float start) = 0;
21+
virtual ~Base() = default;
22+
};
23+
24+
class TestClass : public Base {
25+
float data = 0.0f;
26+
27+
public:
28+
void virtual_method(float start) override {}
29+
30+
float calculate(float start, size_t id) {
31+
return start + static_cast<float>(id) + data;
32+
}
33+
34+
void setData(float value) { data = value; }
35+
};
36+
37+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>))
38+
void func_range(TestClass *acc, float *ptr) {
39+
size_t id = syclext::this_work_item::get_nd_item<1>().get_global_linear_id();
40+
ptr[id] = acc->calculate(3.14f, id);
41+
}
42+
43+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::single_task_kernel))
44+
void func_single(TestClass *acc, float *ptr) {
45+
size_t id = syclext::this_work_item::get_nd_item<1>().get_global_linear_id();
46+
ptr[id] = acc->calculate(3.14f, id);
47+
}
48+
49+
int check_result(float *ptr) {
50+
for (size_t i = 0; i < NUM; ++i) {
51+
const float expected = 3.14f + static_cast<float>(i) + offset;
52+
if (ptr[i] != expected)
53+
return 1;
54+
}
55+
return 0;
56+
}
57+
58+
static int call_kernel_code(sycl::queue &q, sycl::kernel &kernel) {
59+
float *ptr = sycl::malloc_shared<float>(NUM, q);
60+
TestClass *obj = sycl::malloc_shared<TestClass>(1, q);
61+
obj->setData(offset);
62+
63+
q.submit([&](sycl::handler &cgh) {
64+
cgh.set_args(obj, ptr);
65+
sycl::nd_range ndr{{NUM}, {WGSIZE}};
66+
cgh.parallel_for(ndr, kernel);
67+
}).wait();
68+
int ret = check_result(ptr);
69+
sycl::free(ptr, q);
70+
sycl::free(obj, q);
71+
return ret;
72+
}
73+
74+
template <auto *Func>
75+
int test_arg_with_virtual_method(sycl::queue &q, sycl::context &ctxt,
76+
std::string_view name) {
77+
auto exe_bndl =
78+
syclexp::get_kernel_bundle<Func, sycl::bundle_state::executable>(ctxt);
79+
sycl::kernel k_func = exe_bndl.template ext_oneapi_get_kernel<Func>();
80+
const int ret = call_kernel_code(q, k_func);
81+
if (ret != 0)
82+
std::cerr << FFTestMark << name << " failed\n";
83+
return ret;
84+
}
85+
86+
int main() {
87+
sycl::queue q;
88+
sycl::context ctxt = q.get_context();
89+
sycl::device dev = q.get_device();
90+
91+
int ret =
92+
test_arg_with_virtual_method<func_range>(q, ctxt, "virtual_method_range");
93+
ret |= test_arg_with_virtual_method<func_single>(q, ctxt,
94+
"virtual_method_single");
95+
return ret;
96+
}

0 commit comments

Comments
 (0)