
Commit a57273f

dot func with strides impl (#930)
* dot func with strides impl
* provide different impl by checking SYCL version
* check version of MKL for GEMM
1 parent 4910209 commit a57273f

File tree

5 files changed: +323 -93 lines

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 238 additions & 44 deletions
@@ -34,6 +34,7 @@
 #include "queue_sycl.hpp"
 
 namespace mkl_blas = oneapi::mkl::blas;
+namespace mkl_blas_rm = oneapi::mkl::blas::row_major;
 namespace mkl_lapack = oneapi::mkl::lapack;
 
 template <typename _DataType, typename _ResultType>
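The new mkl_blas_rm alias exposes the row-major flavor of the oneMKL BLAS namespace; it is used by the GEMM fast path added to dpnp_dot_c below, so C-contiguous operands can be handed to GEMM without a layout conversion.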
@@ -75,6 +76,82 @@ void dpnp_astype_c(const void* array1_in, void* result1, const size_t size)
 template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
 class dpnp_dot_c_kernel;
 
+template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
+cl::sycl::event dot(cl::sycl::queue& queue,
+                    _DataType_output* result_out,
+                    _DataType_input1* input1_in,
+                    _DataType_input2* input2_in,
+                    size_t input1_strides,
+                    size_t input2_strides,
+                    size_t size,
+                    const cl::sycl::vector_class<cl::sycl::event>& dependencies = {})
+{
+    (void)dependencies;
+
+    cl::sycl::event event;
+
+    if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) &&
+                  std::is_same<_DataType_input2, _DataType_input1>::value &&
+                  std::is_same<_DataType_output, _DataType_input1>::value)
+    {
+        event = oneapi::mkl::blas::dot(queue,
+                                       size,
+                                       input1_in,
+                                       input1_strides, // input1 stride
+                                       input2_in,
+                                       input2_strides, // input2 stride
+                                       result_out);
+    }
+    else
+    {
+#if LIBSYCL_VERSION_GREATER(5, 3, 0)
+        event = queue.submit([&](sycl::handler& cgh) {
+            cgh.parallel_for(sycl::range<1>{size},
+                             cl::sycl::reduction(result_out,
+                                                 std::plus<_DataType_output>(),
+                                                 cl::sycl::property::reduction::initialize_to_identity{}),
+                             [=](cl::sycl::id<1> idx, auto& sum) {
+                                 sum += static_cast<_DataType_output>(input1_in[idx * input1_strides]) *
+                                        static_cast<_DataType_output>(input2_in[idx * input2_strides]);
+                             });
+        });
+        // For some reason a few such kernels cannot run in parallel;
+        // it looks like a bug in Level Zero, because with OpenCL it works fine.
+        // That is why we call wait() here.
+        event.wait();
+#else
+        _DataType_output* local_mem =
+            reinterpret_cast<_DataType_output*>(dpnp_memory_alloc_c(size * sizeof(_DataType_output)));
+
+        // what about reduction??
+        cl::sycl::range<1> gws(size);
+
+        auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
+            const size_t index = global_id[0];
+            local_mem[index] = input1_in[index * input1_strides] * input2_in[index * input2_strides];
+        };
+
+        auto kernel_func = [&](cl::sycl::handler& cgh) {
+            cgh.parallel_for<class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
+                gws, kernel_parallel_for_func);
+        };
+
+        event = DPNP_QUEUE.submit(kernel_func);
+
+        event.wait();
+
+        auto policy = oneapi::dpl::execution::make_device_policy<
+            class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(DPNP_QUEUE);
+
+        _DataType_output accumulator = 0;
+        accumulator =
+            std::reduce(policy, local_mem, local_mem + size, _DataType_output(0), std::plus<_DataType_output>());
+        policy.queue().wait();
+
+        dpnp_memory_memcpy_c(result_out, &accumulator, sizeof(_DataType_output)); // result[0] = accumulator;
+
+        free(local_mem, DPNP_QUEUE);
+#endif
+    }
+    return event;
+}
+
 template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
 void dpnp_dot_c(void* result_out,
                 const size_t result_size,
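For reference, independent of the oneMKL/SYCL dispatch above, the new dot() helper computes a plain strided inner product. A minimal host-side C++ sketch of those semantics (dot_reference and its names are illustrative, not part of the diff):

    #include <cstddef>
    #include <iostream>

    // Reference semantics of the dot() helper: sum over i of in1[i * st1] * in2[i * st2].
    // This matches both the oneMKL blas::dot path and the SYCL reduction fallback.
    template <typename TOut, typename TIn1, typename TIn2>
    TOut dot_reference(const TIn1* in1, std::size_t st1, const TIn2* in2, std::size_t st2, std::size_t n)
    {
        TOut acc = TOut(0);
        for (std::size_t i = 0; i < n; ++i)
        {
            acc += static_cast<TOut>(in1[i * st1]) * static_cast<TOut>(in2[i * st2]);
        }
        return acc;
    }

    int main()
    {
        // every 2nd element of a against a contiguous b
        double a[] = {1, 9, 2, 9, 3, 9};
        double b[] = {4, 5, 6};
        std::cout << dot_reference<double>(a, 2, b, 1, 3) << '\n'; // prints 32 (1*4 + 2*5 + 3*6)
        return 0;
    }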
@@ -92,78 +169,195 @@ void dpnp_dot_c(void* result_out,
                 const size_t* input2_shape,
                 const size_t* input2_strides)
 {
-    (void)input1_shape;
-    (void)input1_ndim;
-    (void)input2_shape;
-    (void)input2_ndim;
-
-    (void)result_size;
-    (void)result_ndim;
-    (void)result_shape;
     (void)result_strides;
-    (void)input1_strides;
-    (void)input2_strides;
 
-    cl::sycl::event event;
     DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size);
     DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size);
 
     _DataType_input1* input1 = input1_ptr.get_ptr();
     _DataType_input2* input2 = input2_ptr.get_ptr();
     _DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
 
-    if (!input1_size)
+    if (!input1_size || !input2_size)
     {
+        _DataType_output val = _DataType_output(0);
+        dpnp_initval_c<_DataType_output>(result, &val, result_size);
         return;
     }
 
-    if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) &&
-                  std::is_same<_DataType_input2, _DataType_input1>::value &&
-                  std::is_same<_DataType_output, _DataType_input1>::value)
+    // scalar
+    if ((input1_ndim == 0) || (input2_ndim == 0))
     {
-        event = mkl_blas::dot(DPNP_QUEUE,
-                              input1_size,
-                              input1,
-                              1, // input1 stride
-                              input2,
-                              1, // input2 stride
-                              result);
+        // the multiply function has no support for strides,
+        // so the result can be wrong if an input array has non-standard (non-C-contiguous) strides
+        dpnp_multiply_c<_DataType_output, _DataType_input1, _DataType_input2>(result,
+                                                                              input1_in,
+                                                                              input1_size,
+                                                                              input1_shape,
+                                                                              input1_ndim,
+                                                                              input2_in,
+                                                                              input2_size,
+                                                                              input2_shape,
+                                                                              input2_ndim,
+                                                                              NULL);
+        return;
+    }
+
+    // if both arrays are vectors
+    if ((input1_ndim == 1) && (input2_ndim == 1))
+    {
+        assert(input1_size == input2_size);
+        cl::sycl::event event = dot(DPNP_QUEUE, result, input1, input2, input1_strides[0], input2_strides[0], input1_size);
         event.wait();
+        return;
+    }
+
+    // extend 1D vectors to 2D matrices
+    size_t ext_input1_ndim = input1_ndim == 1 ? 2 : input1_ndim;
+    size_t* ext_input1_shape = new size_t[ext_input1_ndim];
+    size_t* ext_input1_strides = new size_t[ext_input1_ndim];
+    if (input1_ndim == 1)
+    {
+        ext_input1_shape[0] = 1;
+        ext_input1_shape[1] = input1_shape[0];
+        ext_input1_strides[0] = 0;
+        ext_input1_strides[1] = input1_strides[0];
     }
     else
     {
-        _DataType_output* local_mem =
-            reinterpret_cast<_DataType_output*>(dpnp_memory_alloc_c(input1_size * sizeof(_DataType_output)));
+        for (size_t i = 0; i < ext_input1_ndim; ++i)
+        {
+            ext_input1_shape[i] = input1_shape[i];
+            ext_input1_strides[i] = input1_strides[i];
+        }
+    }
+    size_t ext_input2_ndim = input2_ndim == 1 ? 2 : input2_ndim;
+    size_t* ext_input2_shape = new size_t[ext_input2_ndim];
+    size_t* ext_input2_strides = new size_t[ext_input2_ndim];
+    if (input2_ndim == 1)
+    {
+        ext_input2_shape[0] = input2_shape[0];
+        ext_input2_shape[1] = 1;
+        ext_input2_strides[0] = input2_strides[0];
+        ext_input2_strides[1] = 0;
+    }
+    else
+    {
+        for (size_t i = 0; i < ext_input2_ndim; ++i)
+        {
+            ext_input2_shape[i] = input2_shape[i];
+            ext_input2_strides[i] = input2_strides[i];
+        }
+    }
+    size_t ext_result_ndim = ((input1_ndim == 1) || (input2_ndim == 1)) ? 2 : result_ndim;
+    size_t* ext_result_shape = new size_t[ext_result_ndim];
+    if ((input1_ndim == 1) || (input2_ndim == 1))
+    {
+        ext_result_shape[0] = ext_input1_shape[0];
+        ext_result_shape[1] = ext_input2_shape[1];
+    }
+    else
+    {
+        for (size_t i = 0; i < ext_result_ndim; ++i)
+        {
+            ext_result_shape[i] = result_shape[i];
+        }
+    }
 
-        // what about reduction??
-        cl::sycl::range<1> gws(input1_size);
+    // check if GEMM can be executed (types)
+    if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) &&
+                  std::is_same<_DataType_input2, _DataType_input1>::value &&
+                  std::is_same<_DataType_output, _DataType_input1>::value)
+    {
+        // check if GEMM can be executed (strides)
+        // TODO: rewrite the condition for the general case of ndim > 2
+        // (it looks like there are other such cases)
+        if ((ext_input1_ndim == 2 && ext_input2_ndim == 2) &&
+            (ext_input1_strides[0] == 1 || ext_input1_strides[1] == 1) &&
+            (ext_input2_strides[0] == 1 || ext_input2_strides[1] == 1))
+        {
+            // the trans and size parameters behave differently in previous versions of GEMM;
+            // only the new version is supported, with an older MKL the computation goes the common way
+#if INTEL_MKL_VERSION >= 20210004
+            oneapi::mkl::transpose trans1 = ext_input1_strides[0] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
+            oneapi::mkl::transpose trans2 = ext_input2_strides[0] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
+
+            const size_t size_m = ext_input1_shape[0];
+            const size_t size_n = ext_input2_shape[1];
+            const size_t size_k = ext_input1_shape[1];
+
+            const std::int64_t lda = trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0] : ext_input1_strides[1];
+            const std::int64_t ldb = trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0] : ext_input2_strides[1];
+            // the definition of ldc would differ for a result with non-standard (non-C-contiguous) strides
+            // const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
+            const std::int64_t ldc = size_n;
+
+            cl::sycl::event event = mkl_blas_rm::gemm(DPNP_QUEUE,
+                                                      trans1,
+                                                      trans2,
+                                                      size_m,
+                                                      size_n,
+                                                      size_k,
+                                                      _DataType_output(1), // alpha
+                                                      input1,
+                                                      lda,
+                                                      input2,
+                                                      ldb,
+                                                      _DataType_output(0), // beta
+                                                      result,
+                                                      ldc);
+            event.wait();
+            return;
+#endif
+        }
+    }
 
-        auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
-            const size_t index = global_id[0];
-            local_mem[index] = input1[index] * input2[index];
-        };
+    // deprecated? can be replaced with std::vector<cl::sycl::event>
+    cl::sycl::vector_class<cl::sycl::event> dot_events;
+    // std::vector<cl::sycl::event> dot_events;
+    dot_events.reserve(result_size);
 
-        auto kernel_func = [&](cl::sycl::handler& cgh) {
-            cgh.parallel_for<class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
-                gws, kernel_parallel_for_func);
-        };
+    size_t dot_st1 = ext_input1_strides[ext_input1_ndim - 1];
+    size_t dot_st2 = ext_input2_strides[ext_input2_ndim - 2];
+    size_t dot_size = ext_input1_shape[ext_input1_ndim - 1];
 
-        event = DPNP_QUEUE.submit(kernel_func);
+    size_t* res_coords = new size_t[ext_result_ndim];
+    size_t* result_offsets = new size_t[ext_result_ndim];
+    get_shape_offsets_inkernel(ext_result_shape, ext_result_ndim, result_offsets);
 
-        event.wait();
+    for (size_t i = 0; i < result_size; ++i)
+    {
+        get_xyz_by_id(i, ext_result_ndim, result_offsets, res_coords);
 
-        auto policy = oneapi::dpl::execution::make_device_policy<
-            class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(DPNP_QUEUE);
+        _DataType_output* dot_res = result + i;
 
-        _DataType_output accumulator = 0;
-        accumulator =
-            std::reduce(policy, local_mem, local_mem + input1_size, _DataType_output(0), std::plus<_DataType_output>());
-        policy.queue().wait();
+        _DataType_input1* dot_in1 = input1;
+        for (size_t j = 0; j < ext_input1_ndim - 1; ++j)
+        {
+            dot_in1 = dot_in1 + res_coords[j] * ext_input1_strides[j];
+        }
 
-        dpnp_memory_memcpy_c(result, &accumulator, sizeof(_DataType_output)); // result[0] = accumulator;
+        _DataType_input2* dot_in2 = input2;
+        for (size_t j = 0; j < ext_input2_ndim - 2; ++j)
+        {
+            dot_in2 = dot_in2 + res_coords[ext_input1_ndim - 1 + j] * ext_input2_strides[j];
+        }
+        dot_in2 = dot_in2 + res_coords[ext_input1_ndim + ext_input2_ndim - 3] * ext_input2_strides[ext_input2_ndim - 1];
 
-        free(local_mem, DPNP_QUEUE);
+        dot_events.push_back(dot(DPNP_QUEUE, dot_res, dot_in1, dot_in2, dot_st1, dot_st2, dot_size));
     }
+
+    sycl::event::wait(dot_events);
+
+    delete[] res_coords;
+    delete[] result_offsets;
+    delete[] ext_input1_shape;
+    delete[] ext_input1_strides;
+    delete[] ext_input2_shape;
+    delete[] ext_input2_strides;
+    delete[] ext_result_shape;
+
 }
 
 template <typename _DataType, typename _ResultType>
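Two pieces of the new dpnp_dot_c logic are worth unpacking. First, the GEMM fast path maps 2D strides onto row-major GEMM arguments: an operand whose last stride is 1 (C-contiguous rows) is passed as nontrans with the leading dimension taken from the row stride, while an operand whose first stride is 1 (for example a transposed view) is passed as trans. A small self-contained sketch of that mapping (GemmOperand and map_strides are illustrative names, not dpnp API):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Illustrative mapping from 2D strides (s0, s1) to row-major GEMM arguments,
    // mirroring the trans1/trans2 and lda/ldb selection in dpnp_dot_c above:
    //   s1 == 1 (C-contiguous rows)  -> nontrans, ld = s0
    //   s0 == 1 (column-major view)  -> trans,    ld = s1
    struct GemmOperand
    {
        bool trans;
        std::int64_t ld;
    };

    GemmOperand map_strides(std::size_t s0, std::size_t s1)
    {
        GemmOperand op;
        op.trans = (s0 == 1);
        op.ld = static_cast<std::int64_t>(op.trans ? s1 : s0);
        return op;
    }

    int main()
    {
        GemmOperand a = map_strides(4, 1); // 3x4 C-contiguous matrix: nontrans, ld = 4
        GemmOperand b = map_strides(1, 4); // its transposed view:     trans,    ld = 4
        std::printf("a: trans=%d ld=%lld\n", static_cast<int>(a.trans), static_cast<long long>(a.ld));
        std::printf("b: trans=%d ld=%lld\n", static_cast<int>(b.trans), static_cast<long long>(b.ld));
        return 0;
    }

Second, the fallback loop launches one dot per result element by unraveling the flat result index into coordinates via get_shape_offsets_inkernel and get_xyz_by_id. Those helpers live elsewhere in the dpnp backend; assuming they compute row-major offsets and the usual divide/modulo decomposition, an equivalent host-side sketch is:

    #include <cstddef>
    #include <iostream>

    // Assumed behavior of the dpnp helpers used above (the real implementations
    // live in the dpnp backend): row-major offsets and flat-index decomposition.
    void shape_offsets(const std::size_t* shape, std::size_t ndim, std::size_t* offsets)
    {
        std::size_t acc = 1;
        for (std::size_t i = ndim; i-- > 0;)
        {
            offsets[i] = acc; // product of shape[i+1 .. ndim-1]
            acc *= shape[i];
        }
    }

    void unravel(std::size_t id, std::size_t ndim, const std::size_t* offsets, std::size_t* coords)
    {
        for (std::size_t i = 0; i < ndim; ++i)
        {
            coords[i] = id / offsets[i];
            id %= offsets[i];
        }
    }

    int main()
    {
        std::size_t shape[] = {2, 3, 4};
        std::size_t offsets[3];
        shape_offsets(shape, 3, offsets); // offsets = {12, 4, 1}
        std::size_t coords[3];
        unravel(17, 3, offsets, coords);  // 17 = 1*12 + 1*4 + 1*1 -> coords = {1, 1, 1}
        std::cout << coords[0] << ' ' << coords[1] << ' ' << coords[2] << '\n';
        return 0;
    }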

dpnp/backend/kernels/dpnp_krnl_statistics.cpp

Lines changed: 3 additions & 2 deletions
@@ -51,6 +51,7 @@ void dpnp_correlate_c(void* result_out,
 {
     (void)where;
 
+    size_t dummy[] = {1};
     dpnp_dot_c<_DataType_output, _DataType_input1, _DataType_input2>(result_out,
                                                                      42, // dummy result_size
                                                                      42, // dummy result_ndim
@@ -60,12 +61,12 @@ void dpnp_correlate_c(void* result_out,
                                                                      input1_size,
                                                                      input1_shape_ndim,
                                                                      input1_shape,
-                                                                     NULL, // dummy input1_strides
+                                                                     dummy, // dummy input1_strides
                                                                      input2_in,
                                                                      input2_size,
                                                                      input2_shape_ndim,
                                                                      input2_shape,
-                                                                     NULL); // dummy input2_strides
+                                                                     dummy); // dummy input2_strides
 
     return;
 }
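A note on this change: dpnp_dot_c now reads the stride arrays unconditionally (for example input1_strides[0] on the 1D vector path), so the NULL placeholders that dpnp_correlate_c used to pass would be dereferenced. The unit-stride dummy preserves the old contiguous behavior.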
