Skip to content

Commit 409d57d

Browse files
adarshyogaDiptorup Deb
authored andcommitted
fixing knn sycl for single precision execution
1 parent 66e2931 commit 409d57d

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed

dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
#include <CL/sycl.hpp>
66
#include <cmath>
77

8-
struct neighbors
8+
template <typename FpTy, typename IntTy> class theKernel;
9+
10+
template <typename FpTy> struct neighbors
911
{
10-
double dist;
12+
FpTy dist;
1113
size_t label;
1214
};
1315

14-
template <typename FpTy>
16+
template <typename FpTy, typename IntTy>
1517
sycl::event knn_impl(sycl::queue q,
1618
FpTy *d_train,
1719
size_t *d_train_labels,
@@ -20,18 +22,18 @@ sycl::event knn_impl(sycl::queue q,
2022
size_t classes_num,
2123
size_t train_size,
2224
size_t test_size,
23-
size_t *d_predictions,
25+
IntTy *d_predictions,
2426
FpTy *d_votes_to_classes,
2527
size_t data_dim)
2628
{
2729
sycl::event partial_hists_ev = q.submit([&](sycl::handler &h) {
28-
h.parallel_for<class theKernel>(
30+
h.parallel_for<theKernel<FpTy, IntTy>>(
2931
sycl::range<1>{test_size}, [=](sycl::id<1> myID) {
3032
size_t i = myID[0];
3133

3234
// here k has to be 5 in order to match with numpy no. of
3335
// neighbors
34-
struct neighbors queue_neighbors[5];
36+
struct neighbors<FpTy> queue_neighbors[5];
3537

3638
// count distances
3739
for (size_t j = 0; j < k; ++j) {
@@ -102,10 +104,10 @@ sycl::event knn_impl(sycl::queue q,
102104
queue_neighbors[j].label]++;
103105
}
104106

105-
size_t max_ind = 0;
107+
IntTy max_ind = 0;
106108
FpTy max_value = 0.0;
107109

108-
for (size_t j = 0; j < classes_num; ++j) {
110+
for (IntTy j = 0; j < (IntTy)classes_num; ++j) {
109111
if (d_votes_to_classes[i * classes_num + j] > max_value) {
110112
max_value = d_votes_to_classes[i * classes_num + j];
111113
max_ind = j;

dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#include "_knn_kernel.hpp"
66
#include <dpctl4pybind11.hpp>
77

8-
#include <iostream>
9-
108
template <typename... Args> bool ensure_compatibility(const Args &...args)
119
{
1210
std::vector<dpctl::tensor::usm_ndarray> arrays = {args...};
@@ -41,16 +39,28 @@ void knn_sync(dpctl::tensor::usm_ndarray x_train,
4139
votes_to_classes))
4240
throw std::runtime_error("Input arrays are not acceptable.");
4341

44-
if (x_train.get_typenum() != UAR_DOUBLE) {
45-
throw std::runtime_error("Expected a double precision FP array.");
42+
auto typenum = x_train.get_typenum();
43+
if (typenum == UAR_FLOAT) {
44+
sycl::event res_ev = knn_impl<float, unsigned int>(
45+
x_train.get_queue(), x_train.get_data<float>(),
46+
y_train.get_data<size_t>(), x_test.get_data<float>(), k,
47+
classes_num, train_size, test_size,
48+
predictions.get_data<unsigned int>(),
49+
votes_to_classes.get_data<float>(), data_dim);
50+
res_ev.wait();
51+
}
52+
else if (typenum == UAR_DOUBLE) {
53+
sycl::event res_ev = knn_impl<double, size_t>(
54+
x_train.get_queue(), x_train.get_data<double>(),
55+
y_train.get_data<size_t>(), x_test.get_data<double>(), k,
56+
classes_num, train_size, test_size, predictions.get_data<size_t>(),
57+
votes_to_classes.get_data<double>(), data_dim);
58+
res_ev.wait();
59+
}
60+
else {
61+
throw std::runtime_error(
62+
"Expected a double or single precision FP array.");
4663
}
47-
48-
sycl::event res_ev = knn_impl(
49-
x_train.get_queue(), x_train.get_data<double>(),
50-
y_train.get_data<size_t>(), x_test.get_data<double>(), k, classes_num,
51-
train_size, test_size, predictions.get_data<size_t>(),
52-
votes_to_classes.get_data<double>(), data_dim);
53-
res_ev.wait();
5464
}
5565

5666
PYBIND11_MODULE(_knn_sycl, m)

0 commit comments

Comments
 (0)