@@ -363,7 +363,7 @@ static void silu_sycl(const T *x, T *dst, const int k,
363363template <typename T>
364364static void sgn_sycl (const T * x, T * dst, const int k, queue_ptr stream) {
365365 // hard code for now
366- const int num_blocks = (k + 256 - 1 ) / 256 ;
366+ const int num_blocks = ceil_div (k, 256 ) ;
367367 stream->parallel_for (
368368 sycl::nd_range<3 >((sycl::range<3 >(1 , 1 , num_blocks) * sycl::range (1 , 1 , 256 )), sycl::range (1 , 1 , 256 )), [=](sycl::nd_item<3 > item_ct1) {
369369 sgn (x, dst, k, item_ct1);
@@ -373,7 +373,7 @@ static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
373373template <typename T>
374374static void abs_sycl (const T * x, T * dst, const int k, queue_ptr stream) {
375375 // hard code for now
376- const int num_blocks = (k + 256 - 1 ) / 256 ;
376+ const int num_blocks = ceil_div (k, 256 ) ;
377377 stream->parallel_for (
378378 sycl::nd_range<3 >((sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , 256 )), sycl::range<3 >(1 , 1 , 256 )), [=](sycl::nd_item<3 > item_ct1) {
379379 abs_op (x, dst, k, item_ct1);
@@ -384,7 +384,7 @@ static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
384384template <typename T>
385385static void elu_sycl (const T * x, T * dst, const int k, queue_ptr stream) {
386386 // hard code for now
387- const int num_blocks = (k + 256 - 1 ) / 256 ;
387+ const int num_blocks = ceil_div (k, 256 ) ;
388388 stream->parallel_for (
389389 sycl::nd_range<3 >((sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , 256 )), sycl::range<3 >(1 , 1 , 256 )), [=](sycl::nd_item<3 > item_ct1) {
390390 elu_op (x, dst, k, item_ct1);
0 commit comments