#include "queue_sycl.hpp"

namespace mkl_blas = oneapi::mkl::blas;
+ namespace mkl_blas_rm = oneapi::mkl::blas::row_major;
namespace mkl_lapack = oneapi::mkl::lapack;

template <typename _DataType, typename _ResultType>
@@ -75,6 +76,82 @@ void dpnp_astype_c(const void* array1_in, void* result1, const size_t size)
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
class dpnp_dot_c_kernel;

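+ // Helper: dot product of two strided 1-D views of length `size`. Dispatches to
+ // oneapi::mkl::blas::dot when all three types are the same real floating-point
+ // type, otherwise falls back to a SYCL kernel. `dependencies` is accepted for
+ // future event chaining but is currently unused.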
+ template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
+ cl::sycl::event dot(cl::sycl::queue& queue,
+                     _DataType_output* result_out,
+                     _DataType_input1* input1_in,
+                     _DataType_input2* input2_in,
+                     size_t input1_strides,
+                     size_t input2_strides,
+                     size_t size,
+                     const cl::sycl::vector_class<cl::sycl::event>& dependencies = {})
+ {
+     (void)dependencies;
+
+     cl::sycl::event event;
+
+     if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) &&
+                   std::is_same<_DataType_input2, _DataType_input1>::value &&
+                   std::is_same<_DataType_output, _DataType_input1>::value)
+     {
+         event = oneapi::mkl::blas::dot(queue,
+                                        size,
+                                        input1_in,
+                                        input1_strides, // input1 stride
+                                        input2_in,
+                                        input2_strides, // input2 stride
+                                        result_out);
+     }
+     else
+     {
+ #if LIBSYCL_VERSION_GREATER(5, 3, 0)
+         event = queue.submit([&](sycl::handler& cgh)
+         {
+             cgh.parallel_for(sycl::range<1>{size},
+                              cl::sycl::reduction(result_out,
+                                                  std::plus<_DataType_output>(),
+                                                  cl::sycl::property::reduction::initialize_to_identity{}),
+                              [=](cl::sycl::id<1> idx, auto& sum)
+                              {
+                                  sum += static_cast<_DataType_output>(input1_in[idx * input1_strides]) *
+                                         static_cast<_DataType_output>(input2_in[idx * input2_strides]);
+                              });
+         });
+         // for some reason several such kernels cannot run in parallel;
+         // it looks like a Level Zero bug, because with OpenCL it works fine;
+         // that is why we call wait here
+         event.wait();
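+         // note: initialize_to_identity above makes the reduction overwrite
+         // result_out rather than accumulate into its previous contents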
+ #else
+         _DataType_output* local_mem =
+             reinterpret_cast<_DataType_output*>(dpnp_memory_alloc_c(size * sizeof(_DataType_output)));
+
+         // TODO: replace this map + std::reduce fallback with a single reduction kernel
+         cl::sycl::range<1> gws(size);
+
+         auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
+             const size_t index = global_id[0];
+             local_mem[index] = input1_in[index * input1_strides] * input2_in[index * input2_strides];
+         };
+
+         auto kernel_func = [&](cl::sycl::handler& cgh) {
+             cgh.parallel_for<class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
+                 gws, kernel_parallel_for_func);
+         };
+
+         event = DPNP_QUEUE.submit(kernel_func);
+
+         event.wait();
+
+         auto policy = oneapi::dpl::execution::make_device_policy<
+             class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(DPNP_QUEUE);
+
+         _DataType_output accumulator = 0;
+         accumulator =
+             std::reduce(policy, local_mem, local_mem + size, _DataType_output(0), std::plus<_DataType_output>());
+         policy.queue().wait();
+
+         dpnp_memory_memcpy_c(result_out, &accumulator, sizeof(_DataType_output)); // result[0] = accumulator
+
+         free(local_mem, DPNP_QUEUE);
+ #endif
+     }
+     return event;
+ }
+
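+ // Example (hypothetical names): for two contiguous USM vectors x and y of
+ // length n, both element strides are 1:
+ //     cl::sycl::event e = dot(DPNP_QUEUE, res, x, y, 1, 1, n);
+ //     e.wait();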
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
void dpnp_dot_c(void* result_out,
                const size_t result_size,
@@ -92,78 +169,195 @@ void dpnp_dot_c(void* result_out,
                const size_t* input2_shape,
                const size_t* input2_strides)
{
-     (void)input1_shape;
-     (void)input1_ndim;
-     (void)input2_shape;
-     (void)input2_ndim;
-
-     (void)result_size;
-     (void)result_ndim;
-     (void)result_shape;
    (void)result_strides;
-     (void)input1_strides;
-     (void)input2_strides;

-     cl::sycl::event event;
    DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size);
    DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size);

    _DataType_input1* input1 = input1_ptr.get_ptr();
    _DataType_input2* input2 = input2_ptr.get_ptr();
    _DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);

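+     // empty inputs: fill the result with zeros (a dot over an empty axis sums
+     // to zero) instead of leaving it uninitialized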
- if (!input1_size)
+     if (!input1_size || !input2_size)
    {
+         _DataType_output val = _DataType_output(0);
+         dpnp_initval_c<_DataType_output>(result, &val, result_size);
        return;
    }

-     if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) &&
-                   std::is_same<_DataType_input2, _DataType_input1>::value &&
-                   std::is_same<_DataType_output, _DataType_input1>::value)
+     // scalar case: a 0-d operand reduces dot to an elementwise multiply
+     if ((input1_ndim == 0) || (input2_ndim == 0))
    {
-         event = mkl_blas::dot(DPNP_QUEUE,
-                               input1_size,
-                               input1,
-                               1, // input1 stride
-                               input2,
-                               1, // input2 stride
-                               result);
+         // dpnp_multiply_c has no stride support, so the result can be wrong
+         // if an input array has non-standard (non-C-contiguous) strides
+         dpnp_multiply_c<_DataType_output, _DataType_input1, _DataType_input2>(result,
+                                                                               input1_in,
+                                                                               input1_size,
+                                                                               input1_shape,
+                                                                               input1_ndim,
+                                                                               input2_in,
+                                                                               input2_size,
+                                                                               input2_shape,
+                                                                               input2_ndim,
+                                                                               NULL);
+         return;
+     }
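+     // e.g. np.dot(3.0, x) == 3.0 * x; NumPy documents the 0-d case as
+     // equivalent to numpy.multiply(a, b)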
+
+     // if both arrays are vectors, a single strided dot is enough
+     if ((input1_ndim == 1) && (input2_ndim == 1))
+     {
+         assert(input1_size == input2_size);
+         cl::sycl::event event = dot(DPNP_QUEUE, result, input1, input2, input1_strides[0], input2_strides[0], input1_size);
        event.wait();
+         return;
+     }
+
+     // promote 1-D operands to 2-D views so the code below can assume matrices
+     size_t ext_input1_ndim = input1_ndim == 1 ? 2 : input1_ndim;
+     size_t* ext_input1_shape = new size_t[ext_input1_ndim];
+     size_t* ext_input1_strides = new size_t[ext_input1_ndim];
+     if (input1_ndim == 1)
+     {
+         ext_input1_shape[0] = 1;
+         ext_input1_shape[1] = input1_shape[0];
+         ext_input1_strides[0] = 0;
+         ext_input1_strides[1] = input1_strides[0];
    }
    else
    {
+         for (size_t i = 0; i < ext_input1_ndim; ++i)
+         {
+             ext_input1_shape[i] = input1_shape[i];
+             ext_input1_strides[i] = input1_strides[i];
+         }
+     }
+     size_t ext_input2_ndim = input2_ndim == 1 ? 2 : input2_ndim;
+     size_t* ext_input2_shape = new size_t[ext_input2_ndim];
+     size_t* ext_input2_strides = new size_t[ext_input2_ndim];
+     if (input2_ndim == 1)
+     {
+         ext_input2_shape[0] = input2_shape[0];
+         ext_input2_shape[1] = 1;
+         ext_input2_strides[0] = input2_strides[0];
+         ext_input2_strides[1] = 0;
+     }
+     else
+     {
+         for (size_t i = 0; i < ext_input2_ndim; ++i)
+         {
+             ext_input2_shape[i] = input2_shape[i];
+             ext_input2_strides[i] = input2_strides[i];
+         }
+     }
+     size_t ext_result_ndim = ((input1_ndim == 1) || (input2_ndim == 1)) ? 2 : result_ndim;
+     size_t* ext_result_shape = new size_t[ext_result_ndim];
+     if ((input1_ndim == 1) || (input2_ndim == 1))
+     {
+         ext_result_shape[0] = ext_input1_shape[0];
+         ext_result_shape[1] = ext_input2_shape[1];
+     }
+     else
+     {
+         for (size_t i = 0; i < ext_result_ndim; ++i)
+         {
+             ext_result_shape[i] = result_shape[i];
+         }
+     }
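+     // e.g. input1 of shape (n,) becomes a (1, n) view with outer stride 0, and
+     // input2 of shape (n,) becomes (n, 1), mirroring NumPy's matmul promotion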

-         // what about reduction??
-         cl::sycl::range<1> gws(input1_size);
+     // check whether GEMM can be used (types)
+     if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) &&
+                   std::is_same<_DataType_input2, _DataType_input1>::value &&
+                   std::is_same<_DataType_output, _DataType_input1>::value)
+     {
+         // check whether GEMM can be used (strides)
+         // TODO: generalize this condition for ndims > 2
+         // (it looks like there are other such cases)
+         if ((ext_input1_ndim == 2 && ext_input2_ndim == 2) &&
+             (ext_input1_strides[0] == 1 || ext_input1_strides[1] == 1) &&
+             (ext_input2_strides[0] == 1 || ext_input2_strides[1] == 1))
+         {
+             // older GEMM versions interpret the trans and size parameters differently;
+             // only the new version is supported, with an older MKL the computation
+             // falls through to the generic path below
+ #if INTEL_MKL_VERSION >= 20210004
+             oneapi::mkl::transpose trans1 = ext_input1_strides[0] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
+             oneapi::mkl::transpose trans2 = ext_input2_strides[0] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
+
+             const size_t size_m = ext_input1_shape[0];
+             const size_t size_n = ext_input2_shape[1];
+             const size_t size_k = ext_input1_shape[1];
+
+             const std::int64_t lda = trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0] : ext_input1_strides[1];
+             const std::int64_t ldb = trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0] : ext_input2_strides[1];
+             // the definition of ldc would differ for a result with non-standard
+             // (non-C-contiguous) strides:
+             // const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
+             const std::int64_t ldc = size_n;
+
+             cl::sycl::event event = mkl_blas_rm::gemm(DPNP_QUEUE,
+                                                       trans1,
+                                                       trans2,
+                                                       size_m,
+                                                       size_n,
+                                                       size_k,
+                                                       _DataType_output(1), // alpha
+                                                       input1,
+                                                       lda,
+                                                       input2,
+                                                       ldb,
+                                                       _DataType_output(0), // beta
+                                                       result,
+                                                       ldc);
+             event.wait();
+             // release the temporary shape/stride buffers before the early return
+             delete[] ext_input1_shape;
+             delete[] ext_input1_strides;
+             delete[] ext_input2_shape;
+             delete[] ext_input2_strides;
+             delete[] ext_result_shape;
+             return;
+ #endif
+         }
+     }
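+     // Note on the GEMM path above: an operand with stride[1] == 1 is already
+     // row-major (nontrans, ld = stride[0]); stride[0] == 1 means the data is
+     // column-major, which row_major::gemm consumes as the transpose of its
+     // row-major view (ld = stride[1]).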

-         auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
-             const size_t index = global_id[0];
-             local_mem[index] = input1[index] * input2[index];
-         };
+     // cl::sycl::vector_class is deprecated and can be replaced with
+     // std::vector<cl::sycl::event>
+     cl::sycl::vector_class<cl::sycl::event> dot_events;
+     dot_events.reserve(result_size);

-         auto kernel_func = [&](cl::sycl::handler& cgh) {
-             cgh.parallel_for<class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
-                 gws, kernel_parallel_for_func);
-         };
+     // each output element is a dot of length dot_size: input1 advances along its
+     // last axis, input2 along its second-to-last axis (down a column)
+     size_t dot_st1 = ext_input1_strides[ext_input1_ndim - 1];
+     size_t dot_st2 = ext_input2_strides[ext_input2_ndim - 2];
+     size_t dot_size = ext_input1_shape[ext_input1_ndim - 1];

-         event = DPNP_QUEUE.submit(kernel_func);
+     size_t* res_coords = new size_t[ext_result_ndim];
+     size_t* result_offsets = new size_t[ext_result_ndim];
+     get_shape_offsets_inkernel(ext_result_shape, ext_result_ndim, result_offsets);
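+     // get_shape_offsets_inkernel fills result_offsets with the row-major pitches
+     // of ext_result_shape; get_xyz_by_id below uses them to decode the flat index
+     // i into the multi-index res_coords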

-         event.wait();
+     // launch one strided dot per result element
+     for (size_t i = 0; i < result_size; ++i)
+     {
+         get_xyz_by_id(i, ext_result_ndim, result_offsets, res_coords);

-         auto policy = oneapi::dpl::execution::make_device_policy<
-             class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(DPNP_QUEUE);
+         _DataType_output* dot_res = result + i;

-         _DataType_output accumulator = 0;
-         accumulator =
-             std::reduce(policy, local_mem, local_mem + input1_size, _DataType_output(0), std::plus<_DataType_output>());
-         policy.queue().wait();
+         // offset input1 by the leading (row) coordinates of this result element
+         _DataType_input1* dot_in1 = input1;
+         for (size_t j = 0; j < ext_input1_ndim - 1; ++j)
+         {
+             dot_in1 = dot_in1 + res_coords[j] * ext_input1_strides[j];
+         }

-         dpnp_memory_memcpy_c(result, &accumulator, sizeof(_DataType_output)); // result[0] = accumulator;
+         // offset input2 by the trailing coordinates, the last one being the column index
+         _DataType_input2* dot_in2 = input2;
+         for (size_t j = 0; j < ext_input2_ndim - 2; ++j)
+         {
+             dot_in2 = dot_in2 + res_coords[ext_input1_ndim - 1 + j] * ext_input2_strides[j];
+         }
+         dot_in2 = dot_in2 + res_coords[ext_input1_ndim + ext_input2_ndim - 3] * ext_input2_strides[ext_input2_ndim - 1];

-         free(local_mem, DPNP_QUEUE);
+         dot_events.push_back(dot(DPNP_QUEUE, dot_res, dot_in1, dot_in2, dot_st1, dot_st2, dot_size));
    }
+
+     sycl::event::wait(dot_events);
+
+     delete[] res_coords;
+     delete[] result_offsets;
+     delete[] ext_input1_shape;
+     delete[] ext_input1_strides;
+     delete[] ext_input2_shape;
+     delete[] ext_input2_strides;
+     delete[] ext_result_shape;
+
}

template <typename _DataType, typename _ResultType>