77
88template <typename FpTy, typename IntTy> class theKernel ;
99
10- template <typename FpTy> struct neighbors
10+ template <typename FpTy, typename IntTy > struct neighbors
1111{
1212 FpTy dist;
13- size_t label;
13+ IntTy label;
1414};
1515
1616template <typename FpTy, typename IntTy>
1717sycl::event knn_impl (sycl::queue q,
1818 FpTy *d_train,
19- size_t *d_train_labels,
19+ IntTy *d_train_labels,
2020 FpTy *d_test,
2121 size_t k,
2222 size_t classes_num,
@@ -33,7 +33,7 @@ sycl::event knn_impl(sycl::queue q,
3333
3434 // here k has to be 5 in order to match with numpy no. of
3535 // neighbors
36- struct neighbors <FpTy> queue_neighbors[5 ];
36+ struct neighbors <FpTy, IntTy > queue_neighbors[5 ];
3737
3838 // count distances
3939 for (size_t j = 0 ; j < k; ++j) {
@@ -54,7 +54,7 @@ sycl::event knn_impl(sycl::queue q,
5454 for (size_t j = 0 ; j < k; ++j) {
5555 // push queue
5656 FpTy new_distance = queue_neighbors[j].dist ;
57- FpTy new_neighbor_label = queue_neighbors[j].label ;
57+ IntTy new_neighbor_label = queue_neighbors[j].label ;
5858 size_t index = j;
5959 while (index > 0 &&
6060 new_distance < queue_neighbors[index - 1 ].dist )
@@ -83,7 +83,7 @@ sycl::event knn_impl(sycl::queue q,
8383
8484 // push queue
8585 FpTy new_distance = queue_neighbors[k - 1 ].dist ;
86- FpTy new_neighbor_label = queue_neighbors[k - 1 ].label ;
86+ IntTy new_neighbor_label = queue_neighbors[k - 1 ].label ;
8787 size_t index = k - 1 ;
8888
8989 while (index > 0 &&
0 commit comments