Skip to content

Commit 888dcfe

Browse files
committed
[SYCL][WIP] initial version of get kernel info num_args
1 parent dedadc1 commit 888dcfe

File tree

5 files changed

+154
-7
lines changed

5 files changed

+154
-7
lines changed

sycl/include/sycl/ext/oneapi/get_kernel_info.hpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <sycl/queue.hpp>
1717

1818
#include <vector>
19+
#include <type_traits>
1920

2021
namespace sycl {
2122
inline namespace _V1 {
@@ -69,13 +70,30 @@ get_kernel_info(const context &ctxt) {
6970

7071
template <auto *Func, typename Param>
7172
std::enable_if_t<ext::oneapi::experimental::is_kernel_v<Func>,
72-
typename sycl::detail::is_kernel_device_specific_info_desc<
73-
Param>::return_type>
73+
typename sycl::detail::is_kernel_device_specific_info_desc<
74+
Param>::return_type>
7475
get_kernel_info(const context &ctxt, const device &dev) {
7576
auto Bundle = sycl::ext::oneapi::experimental::get_kernel_bundle<
76-
Func, sycl::bundle_state::executable>(ctxt);
77+
Func, sycl::bundle_state::executable>(ctxt);
7778
return Bundle.template ext_oneapi_get_kernel<Func>().template get_info<Param>(
78-
dev);
79+
dev);
80+
}
81+
82+
83+
template <auto *Func, typename Param>
84+
std::enable_if_t<ext::oneapi::experimental::is_kernel_v<Func>, size_t>
85+
get_kernel_info(const context &ctxt, const device &dev) {
86+
if constexpr (std::is_same_v<Param, sycl::info::kernel::num_args>) {
87+
auto Bundle = sycl::ext::oneapi::experimental::get_kernel_bundle<
88+
Func, sycl::bundle_state::executable>(ctxt);
89+
auto kernel_id = sycl::ext::oneapi::experimental::get_kernel_id<Func>();;
90+
sycl::kernel kernel = Bundle.get_kernel(kernel_id);
91+
return Bundle.template ext_oneapi_get_kernel<Func>()
92+
.template get_info<Param>();
93+
}
94+
sycl::exception(
95+
sycl::make_error_code(sycl::errc::invalid),
96+
"get_kernel_info is not supported for free function kernels.");
7997
}
8098

8199
template <auto *Func, typename Param>

sycl/source/detail/kernel_impl.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#include <memory>
1414

15+
#include <iostream>
16+
1517
namespace sycl {
1618
inline namespace _V1 {
1719
namespace detail {
@@ -104,7 +106,12 @@ std::string_view kernel_impl::getName() const {
104106
}
105107

106108
bool kernel_impl::isBuiltInKernel(const device &Device) const {
109+
std::cout << "isBuiltInKernel" << std::endl;
107110
auto BuiltInKernels = Device.get_info<info::device::built_in_kernel_ids>();
111+
std::cout << "Built-in kernels available on the device:" << std::endl;
112+
for (const auto &kernel_id : BuiltInKernels) {
113+
std::cout << " " << kernel_id.get_name() << std::endl;
114+
}
108115
if (BuiltInKernels.empty())
109116
return false;
110117
std::string KernelName = get_info<info::kernel::function_name>();
@@ -113,6 +120,17 @@ bool kernel_impl::isBuiltInKernel(const device &Device) const {
113120
[&KernelName](kernel_id &Id) { return Id.get_name() == KernelName; }));
114121
}
115122

123+
bool kernel_impl::isFreeFunctionKernel() const {
124+
const auto ids = MKernelBundleImpl->get_kernel_ids();
125+
return std::any_of(
126+
ids.begin(), ids.end(),
127+
[this](const kernel_id &Id) {
128+
const std::string KernelName = Id.get_name();
129+
const auto pos = KernelName.find("__sycl_kernel_");
130+
return pos != std::string::npos;
131+
});
132+
}
133+
116134
void kernel_impl::checkIfValidForNumArgsInfoQuery() const {
117135
if (isInteropOrSourceBased())
118136
return;
@@ -121,6 +139,10 @@ void kernel_impl::checkIfValidForNumArgsInfoQuery() const {
121139
[this](device &Device) { return isBuiltInKernel(Device); }))
122140
return;
123141

142+
if (isFreeFunctionKernel())
143+
return;
144+
std::cout << "checkIfValidForNumArgsInfoQuery has not finded" << std::endl;
145+
124146
throw sycl::exception(
125147
sycl::make_error_code(errc::invalid),
126148
"info::kernel::num_args descriptor may only be used to query a kernel "

sycl/source/detail/kernel_impl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ class kernel_impl {
256256
mutable std::string MName;
257257

258258
bool isBuiltInKernel(const device &Device) const;
259+
bool isFreeFunctionKernel() const;
259260
void checkIfValidForNumArgsInfoQuery() const;
260261

261262
/// Check if the occupancy limits are exceeded for the given kernel launch
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// REQUIRES: level_zero, level_zero_dev_kit
2+
// RUN: %{build} %level_zero_options -o %t.ze.out
3+
// RUN: %{run} %t.ze.out
4+
5+
#include <iostream>
6+
#include <sycl/detail/core.hpp>
7+
#include <sycl/ext/oneapi/get_kernel_info.hpp>
8+
#include <sycl/kernel_bundle.hpp>
9+
#include <sycl/usm.hpp>
10+
11+
12+
namespace syclext = sycl::ext::oneapi;
13+
namespace syclexp = sycl::ext::oneapi::experimental;
14+
15+
static constexpr size_t NUM = 1024;
16+
static constexpr size_t WGSIZE = 16;
17+
static constexpr auto FFTestMark = "Free function Kernel Test:";
18+
19+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>))
20+
void func_range(float start, float *ptr) {}
21+
22+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::single_task_kernel))
23+
void func_single(float start, float *ptr) {}
24+
25+
26+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::single_task_kernel))
27+
void kernel_func(sycl::item<1> idx, float value, sycl::accessor<int, 1> acc) {
28+
29+
}
30+
31+
template <typename T>
32+
static void call_kernel_code(sycl::queue &q, sycl::kernel &kernel) {
33+
T *ptr = sycl::malloc_shared<T>(NUM, q);
34+
q.submit([&](sycl::handler &cgh) {
35+
cgh.set_args(3.14f, ptr);
36+
sycl::nd_range ndr{{NUM}, {WGSIZE}};
37+
cgh.parallel_for(ndr, kernel);
38+
}).wait();
39+
sycl::free(ptr, q);
40+
}
41+
42+
template <auto *Func>
43+
int test_num_args(sycl::context &ctxt, const int expected_num_args) {
44+
const int actual =
45+
syclexp::get_kernel_info<Func, sycl::info::kernel::num_args>(ctxt);
46+
const bool res = actual == expected_num_args;
47+
if (!res)
48+
std::cout << FFTestMark << "test_num_args failed: expected_num_args "
49+
<< expected_num_args << "actual " << actual << std::endl;
50+
return res;
51+
}
52+
53+
int main() {
54+
sycl::queue q;
55+
sycl::context ctx = q.get_context();
56+
sycl::device dev = q.get_device();
57+
58+
auto bundle_range =
59+
syclexp::get_kernel_bundle<func_single, sycl::bundle_state::executable>(
60+
ctx);
61+
62+
auto actual =
63+
syclexp::get_kernel_info<func_single, sycl::info::kernel::num_args>(
64+
ctx, dev);
65+
66+
std::cout << "Actual number of arguments: " << actual << std::endl;
67+
assert(actual == 2 && "kernel should take 2 args");
68+
return 0;
69+
}

sycl/test-e2e/KernelAndProgram/free_function_apis.cpp

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// REQUIRES: aspect-usm_shared_allocations
2-
// RUN: %{build} -o %t.out
2+
// RUN: %{build} --save-temps -o %t.out
33
// RUN: %{run} %t.out
44

55
#include <iostream>
@@ -21,6 +21,11 @@ void ff_2(int *ptr, int start) {
2121
ptr2D[GId.get(0)][GId.get(1)] = LId.get(0) + LId.get(1) + start;
2222
}
2323

24+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
25+
(ext::oneapi::experimental::nd_range_kernel<2>))
26+
void ff_b(int *ptr, int start) {
27+
}
28+
2429
// Templated free function definition.
2530
template <typename T>
2631
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
@@ -96,6 +101,21 @@ bool test_kernel_apis(queue Queue) {
96101
return Pass;
97102
}
98103

104+
template <auto *Func>
105+
bool check_kernel_id(const sycl::kernel k, const sycl::context &ctx) {
106+
namespace exp = ext::oneapi::experimental;
107+
const auto id = exp::get_kernel_id<Func>();
108+
auto exe_bndl =
109+
exp::get_kernel_bundle<Func, sycl::bundle_state::executable>(ctx);
110+
if (!exe_bndl.has_kernel(id))
111+
return false;
112+
const auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx);
113+
const auto kernel_ids = kb.get_kernel_ids();
114+
bool ret =
115+
std::find(kernel_ids.begin(), kernel_ids.end(), id) != kernel_ids.end();
116+
return ret;
117+
}
118+
99119
bool test_bundle_apis(queue Queue) {
100120
bool Pass = true;
101121

@@ -133,6 +153,10 @@ bool test_bundle_apis(queue Queue) {
133153
std::cout << "PassE=" << PassE << std::endl;
134154
Pass &= PassE;
135155

156+
bool PassE2 = ext::oneapi::experimental::is_compatible<ff_2>(Device);
157+
std::cout << "PassE2=" << PassE2 << std::endl;
158+
Pass &= PassE2;
159+
136160
// Check that ff_2 is found in bundle.
137161
kernel_bundle Bundle2 = ext::oneapi::experimental::get_kernel_bundle<
138162
ff_2, bundle_state::executable>(Context);
@@ -144,7 +168,7 @@ bool test_bundle_apis(queue Queue) {
144168
std::cout << "PassG=" << PassG << std::endl;
145169
Pass &= PassG;
146170
kernel Kernel2 = Bundle2.ext_oneapi_get_kernel<ff_2>();
147-
bool PassH = true;
171+
bool PassH = check_kernel_id<ff_2>(Kernel2, Context);
148172
std::cout << "PassH=" << PassH << std::endl;
149173
Pass &= PassH;
150174

@@ -161,7 +185,8 @@ bool test_bundle_apis(queue Queue) {
161185
Pass &= PassJ;
162186
kernel Kernel3 =
163187
Bundle3.ext_oneapi_get_kernel<(void (*)(int *, int))ff_3<int>>();
164-
bool PassK = true;
188+
bool PassK =
189+
check_kernel_id<(void (*)(int *, int))ff_3<int>>(Kernel3, Context);
165190
std::cout << "PassK=" << PassK << std::endl;
166191
Pass &= PassK;
167192

@@ -196,6 +221,18 @@ bool test_bundle_apis(queue Queue) {
196221
std::cout << "PassP=" << PassP << std::endl;
197222
Pass &= PassP;
198223

224+
bool PassO1 = false;
225+
try
226+
{
227+
kernel Kernel51 = Bundle5.ext_oneapi_get_kernel<ff_b>();
228+
std::cout <<"Wrong PATH" << std::endl;
229+
} catch (const sycl::exception &e) {
230+
PassO1 = e.code() == sycl::errc::invalid;
231+
}
232+
std::cout << "PassO1=" << PassO1 << std::endl;
233+
Pass &= PassO1;
234+
235+
199236
bool PassQ =
200237
Bundle6.ext_oneapi_has_kernel<(void (*)(int *, int))ff_3<int>>(Device);
201238
std::cout << "PassQ=" << PassQ << std::endl;

0 commit comments

Comments
 (0)