33// SPDX-License-Identifier: Apache-2.0
44
55#include " _gaussian_kernel.hpp"
6- #include < CL/sycl.hpp>
76#include < dpctl4pybind11.hpp>
87
98template <typename ... Args> bool ensure_compatibility (const Args &...args)
@@ -32,18 +31,14 @@ template <typename... Args> bool ensure_compatibility(const Args &...args)
3231 return true ;
3332}
3433
35- void gaussian_sync (dpctl::tensor::usm_ndarray a,
36- dpctl::tensor::usm_ndarray b,
37- dpctl::tensor::usm_ndarray m,
38- int size,
39- int block_sizeXY,
40- dpctl::tensor::usm_ndarray result)
34+ template <typename FpTy>
35+ void gaussian_handler (FpTy *m_device,
36+ FpTy *a_device,
37+ FpTy *b_device,
38+ FpTy *result,
39+ int size,
40+ int block_sizeXY)
4141{
42- if (!ensure_compatibility (a, m, b, result))
43- throw std::runtime_error (" Input arrays are not acceptable." );
44-
45- int t;
46-
4742 sycl::queue q_ct1;
4843
4944 int block_size, grid_size;
@@ -60,12 +55,7 @@ void gaussian_sync(dpctl::tensor::usm_ndarray a,
6055
6156 sycl::range<3 > dimBlockXY (1 , blocksize2d, blocksize2d);
6257 sycl::range<3 > dimGridXY (1 , gridsize2d, gridsize2d);
63-
64- auto a_value = a.get_data <double >();
65- auto b_value = b.get_data <double >();
66- auto m_value = m.get_data <double >();
67-
68- for (t = 0 ; t < (size - 1 ); t++) {
58+ for (int t = 0 ; t < (size - 1 ); t++) {
6959 /*
7060 DPCT1049:7: The workgroup size passed to the SYCL kernel may
7161 exceed the limit. To get the device limit, query
@@ -76,7 +66,7 @@ void gaussian_sync(dpctl::tensor::usm_ndarray a,
7666 auto size_ct2 = size;
7767 cgh.parallel_for (sycl::nd_range<3 >(dimGrid * dimBlock, dimBlock),
7868 [=](sycl::nd_item<3 > item_ct1) {
79- gaussian_kernel_1 (m_value, a_value , size_ct2,
69+ gaussian_kernel_1 (m_device, a_device , size_ct2,
8070 t, item_ct1);
8171 });
8272 });
@@ -94,28 +84,52 @@ void gaussian_sync(dpctl::tensor::usm_ndarray a,
9484 cgh.parallel_for (
9585 sycl::nd_range<3 >(dimGridXY * dimBlockXY, dimBlockXY),
9686 [=](sycl::nd_item<3 > item_ct1) {
97- gaussian_kernel_2 (m_value, a_value, b_value , size_ct3,
87+ gaussian_kernel_2 (m_device, a_device, b_device , size_ct3,
9888 size_t_ct4, t, item_ct1);
9989 });
10090 });
10191 q_ct1.wait_and_throw ();
10292 }
103- // Copying the final answer
104- auto result_value = result.get_data <double >();
10593
10694 for (int i = 0 ; i < size; i++) {
10795
108- result_value [size - i - 1 ] = b_value [size - i - 1 ];
96+ result [size - i - 1 ] = b_device [size - i - 1 ];
10997
11098 for (int j = 0 ; j < i; j++) {
111- result_value [size - i - 1 ] -=
112- *(a_value + size * (size - i - 1 ) + (size - j - 1 )) *
113- result_value [size - j - 1 ];
99+ result [size - i - 1 ] -=
100+ *(a_device + size * (size - i - 1 ) + (size - j - 1 )) *
101+ result [size - j - 1 ];
114102 }
115103
116- result_value[size - i - 1 ] =
117- result_value[size - i - 1 ] /
118- *(a_value + size * (size - i - 1 ) + (size - i - 1 ));
104+ result[size - i - 1 ] =
105+ result[size - i - 1 ] /
106+ *(a_device + size * (size - i - 1 ) + (size - i - 1 ));
107+ }
108+ }
109+
110+ void gaussian_sync (dpctl::tensor::usm_ndarray a,
111+ dpctl::tensor::usm_ndarray b,
112+ dpctl::tensor::usm_ndarray m,
113+ int size,
114+ int block_sizeXY,
115+ dpctl::tensor::usm_ndarray result)
116+ {
117+ if (!ensure_compatibility (a, m, b, result))
118+ throw std::runtime_error (" Input arrays are not acceptable." );
119+
120+ if (a.get_typenum () == UAR_DOUBLE) {
121+ gaussian_handler (m.get_data <double >(), a.get_data <double >(),
122+ b.get_data <double >(), result.get_data <double >(), size,
123+ block_sizeXY);
124+ }
125+ else if (a.get_typenum () == UAR_FLOAT) {
126+ gaussian_handler (m.get_data <float >(), a.get_data <float >(),
127+ b.get_data <float >(), result.get_data <float >(), size,
128+ block_sizeXY);
129+ }
130+ else {
131+ throw std::runtime_error (
132+ " Expected a double or single precision FP array." );
119133 }
120134}
121135
0 commit comments