Skip to content

Commit 56d09ea

Browse files
committed
Merge branch 'develop' into SYCL_invocation
2 parents 5950201 + e10ed85 commit 56d09ea

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

include/plssvm/backends/SYCL/kernel/predict/work_group/predict_kernel.hpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,8 @@ class device_kernel_w_linear {
45 45
* @param[in] grid_y_offset the offset in y-dimension into the data points if more than one execution grid has to be used
46 46
*/
47 47
device_kernel_w_linear(::sycl::handler &cgh, real_type *w_d, const real_type *alpha_d, const real_type *sv_d, const std::size_t num_classes, const std::size_t num_sv, const std::size_t device_specific_num_sv, const std::size_t sv_offset, const std::size_t grid_x_offset, const std::size_t grid_y_offset) :
48-
data_cache_feature_{ ::sycl::range<2>{ static_cast<std::size_t>(FEATURE_BLOCK_SIZE), static_cast<std::size_t>(INTERNAL_BLOCK_SIZE) * static_cast<std::size_t>(THREAD_BLOCK_SIZE) }, cgh },
49-
data_cache_alpha_{ ::sycl::range<2>{ static_cast<std::size_t>(FEATURE_BLOCK_SIZE), static_cast<std::size_t>(INTERNAL_BLOCK_SIZE) * static_cast<std::size_t>(THREAD_BLOCK_SIZE) }, cgh },
48+
data_cache_feature_{ ::sycl::range<2>{ static_cast<std::size_t>(THREAD_BLOCK_SIZE), static_cast<std::size_t>(INTERNAL_BLOCK_SIZE) * static_cast<std::size_t>(THREAD_BLOCK_SIZE) }, cgh },
49+
data_cache_alpha_{ ::sycl::range<2>{ static_cast<std::size_t>(THREAD_BLOCK_SIZE), static_cast<std::size_t>(INTERNAL_BLOCK_SIZE) * static_cast<std::size_t>(THREAD_BLOCK_SIZE) }, cgh },
50 50
w_d_{ w_d },
51 51
alpha_d_{ alpha_d },
52 52
sv_d_{ sv_d },
@@ -93,8 +93,8 @@ class device_kernel_w_linear {
93 93
const auto global_class_idx = class_idx_linear + static_cast<std::size_t>(internal) * THREAD_BLOCK_SIZE_uz;
94 94
const auto global_feature_idx = feature_idx_linear + static_cast<std::size_t>(internal) * THREAD_BLOCK_SIZE_uz;
95 95

96-
data_cache_feature_[local_id_0][internal * THREAD_BLOCK_SIZE + local_id_1] = sv_d_[global_feature_idx * (device_specific_num_sv_ + PADDING_SIZE_uz) + sv + sv_offset_ + threadIdx_x]; // SoA
97-
data_cache_alpha_[local_id_0][internal * THREAD_BLOCK_SIZE + local_id_1] = alpha_d_[global_class_idx * (num_sv_ + PADDING_SIZE_uz) + sv + threadIdx_x]; // AoS
96+
data_cache_feature_[local_id_0][internal * THREAD_BLOCK_SIZE + local_id_1] = sv_d_[global_feature_idx * (device_specific_num_sv_ + PADDING_SIZE_uz) + sv + threadIdx_x]; // SoA
97+
data_cache_alpha_[local_id_0][internal * THREAD_BLOCK_SIZE + local_id_1] = alpha_d_[global_class_idx * (num_sv_ + PADDING_SIZE_uz) + sv + sv_offset_ + threadIdx_x]; // AoS
98 98
}
99 99
nd_idx.barrier(); // wait until all work-items loaded their part of the data
100 100

0 commit comments

Comments (0)