Skip to content

Commit b0f6461

Browse files
Linalg kron func (#507)
* kron func impl
1 parent d206e0f commit b0f6461

File tree

8 files changed

+168
-11
lines changed

8 files changed

+168
-11
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ enum class DPNPFuncName : size_t
9898
DPNP_FN_FLOOR_DIVIDE, /**< Used in numpy.floor_divide() implementation */
9999
DPNP_FN_FMOD, /**< Used in numpy.fmod() implementation */
100100
DPNP_FN_HYPOT, /**< Used in numpy.hypot() implementation */
101-
DPNP_FN_INVERT, /**< Used in numpy.invert() implementation */
102101
DPNP_FN_INV, /**< Used in numpy.linalg.inv() implementation */
102+
DPNP_FN_INVERT, /**< Used in numpy.invert() implementation */
103+
DPNP_FN_KRON, /**< Used in numpy.kron() implementation */
103104
DPNP_FN_LEFT_SHIFT, /**< Used in numpy.left_shift() implementation */
104105
DPNP_FN_LOG, /**< Used in numpy.log() implementation */
105106
DPNP_FN_LOG10, /**< Used in numpy.log10() implementation */

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,73 @@ void dpnp_inv_c(void* array1_in, void* result1, size_t* shape, size_t ndim)
275275
return;
276276
}
277277

278+
template <typename _DataType1, typename _DataType2, typename _ResultType>
279+
class dpnp_kron_c_kernel;
280+
281+
template <typename _DataType1, typename _DataType2, typename _ResultType>
282+
void dpnp_kron_c(void* array1_in,
283+
void* array2_in,
284+
void* result1,
285+
size_t* in1_shape,
286+
size_t* in2_shape,
287+
size_t* res_shape,
288+
size_t ndim)
289+
{
290+
_DataType1* array1 = reinterpret_cast<_DataType1*>(array1_in);
291+
_DataType2* array2 = reinterpret_cast<_DataType2*>(array2_in);
292+
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
293+
294+
size_t size = 1;
295+
for (size_t i = 0; i < ndim; ++i)
296+
{
297+
size *= res_shape[i];
298+
}
299+
300+
size_t* _in1_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
301+
size_t* _in2_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
302+
303+
dpnp_memory_memcpy_c(_in1_shape, in1_shape, ndim * sizeof(size_t));
304+
dpnp_memory_memcpy_c(_in2_shape, in2_shape, ndim * sizeof(size_t));
305+
306+
size_t* in1_offsets = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
307+
size_t* in2_offsets = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
308+
size_t* res_offsets = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
309+
310+
get_shape_offsets_inkernel<size_t>(in1_shape, ndim, in1_offsets);
311+
get_shape_offsets_inkernel<size_t>(in2_shape, ndim, in2_offsets);
312+
get_shape_offsets_inkernel<size_t>(res_shape, ndim, res_offsets);
313+
314+
cl::sycl::range<1> gws(size);
315+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
316+
const size_t idx = global_id[0];
317+
318+
size_t idx1 = 0;
319+
size_t idx2 = 0;
320+
size_t reminder = idx;
321+
for (size_t axis = 0; axis < ndim; ++axis)
322+
{
323+
const size_t res_axis = reminder / res_offsets[axis];
324+
reminder = reminder - res_axis * res_offsets[axis];
325+
326+
const size_t in1_axis = res_axis / _in2_shape[axis];
327+
const size_t in2_axis = res_axis - in1_axis * _in2_shape[axis];
328+
329+
idx1 += in1_axis * in1_offsets[axis];
330+
idx2 += in2_axis * in2_offsets[axis];
331+
}
332+
333+
result[idx] = array1[idx1] * array2[idx2];
334+
};
335+
336+
auto kernel_func = [&](cl::sycl::handler& cgh) {
337+
cgh.parallel_for<class dpnp_kron_c_kernel<_DataType1, _DataType2, _ResultType>>(gws, kernel_parallel_for_func);
338+
};
339+
340+
cl::sycl::event event = DPNP_QUEUE.submit(kernel_func);
341+
342+
event.wait();
343+
}
344+
278345
template <typename _DataType>
279346
class dpnp_matrix_rank_c_kernel;
280347

@@ -379,6 +446,41 @@ void func_map_init_linalg_func(func_map_t& fmap)
379446
fmap[DPNPFuncName::DPNP_FN_INV][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_inv_c<float>};
380447
fmap[DPNPFuncName::DPNP_FN_INV][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_inv_c<double>};
381448

449+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_kron_c<int, int, int>};
450+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_kron_c<int, long, long>};
451+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_kron_c<int, float, float>};
452+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_kron_c<int, double, double>};
453+
// fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_C128] = {
454+
// eft_C128, (void*)dpnp_kron_c<int, std::complex<double>, std::complex<double>>};
455+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_INT] = {eft_LNG, (void*)dpnp_kron_c<long, int, long>};
456+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_kron_c<long, long, long>};
457+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_kron_c<long, float, float>};
458+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_kron_c<long, double, double>};
459+
// fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_C128] = {
460+
// eft_C128, (void*)dpnp_kron_c<long, std::complex<double>, std::complex<double>>};
461+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_INT] = {eft_FLT, (void*)dpnp_kron_c<float, int, float>};
462+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_LNG] = {eft_FLT, (void*)dpnp_kron_c<float, long, float>};
463+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_kron_c<float, float, float>};
464+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_kron_c<float, double, double>};
465+
// fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_C128] = {
466+
// eft_C128, (void*)dpnp_kron_c<float, std::complex<double>, std::complex<double>>};
467+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_INT] = {eft_DBL, (void*)dpnp_kron_c<double, int, double>};
468+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_LNG] = {eft_DBL, (void*)dpnp_kron_c<double, long, double>};
469+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_FLT] = {eft_DBL, (void*)dpnp_kron_c<double, float, double>};
470+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_kron_c<double, double, double>};
471+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_C128] = {
472+
eft_C128, (void*)dpnp_kron_c<double, std::complex<double>, std::complex<double>>};
473+
// fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_INT] = {
474+
// eft_C128, (void*)dpnp_kron_c<std::complex<double>, int, std::complex<double>>};
475+
// fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_LNG] = {
476+
// eft_C128, (void*)dpnp_kron_c<std::complex<double>, long, std::complex<double>>};
477+
// fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_FLT] = {
478+
// eft_C128, (void*)dpnp_kron_c<std::complex<double>, float, std::complex<double>>};
479+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_DBL] = {
480+
eft_C128, (void*)dpnp_kron_c<std::complex<double>, double, std::complex<double>>};
481+
fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_C128] = {
482+
eft_C128, (void*)dpnp_kron_c<std::complex<double>, std::complex<double>, std::complex<double>>};
483+
382484
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_matrix_rank_c<int>};
383485
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_matrix_rank_c<long>};
384486
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_matrix_rank_c<float>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
7373
DPNP_FN_HYPOT
7474
DPNP_FN_INV
7575
DPNP_FN_INVERT
76+
DPNP_FN_KRON
7677
DPNP_FN_LEFT_SHIFT
7778
DPNP_FN_LOG
7879
DPNP_FN_LOG10

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ cpdef dparray dpnp_array(obj, dtype=None):
9898
cdef elem_dtype
9999
cdef dparray_shape_type obj_shape
100100

101+
# convert scalar to tuple
102+
if dpnp.isscalar(obj):
103+
obj = (obj, )
104+
101105
if not cpython.PySequence_Check(obj):
102106
raise TypeError(f"DPNP array(): Unsupported non-sequence obj={type(obj)}")
103107

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,15 @@ cimport numpy
3939
__all__ += [
4040
"dpnp_dot",
4141
"dpnp_inner",
42+
"dpnp_kron",
4243
"dpnp_outer"
4344
]
4445

4546

47+
# C function pointer to the C library template functions
48+
ctypedef void(*fptr_2in_1out_shapes_t)(void * , void * , void * , size_t * , size_t * , size_t * , size_t)
49+
50+
4651
cpdef dparray dpnp_dot(dparray in_array1, dparray in_array2):
4752
cdef vector[Py_ssize_t] shape1 = in_array1.shape
4853
cdef vector[Py_ssize_t] shape2 = in_array2.shape
@@ -148,6 +153,45 @@ cpdef dparray dpnp_inner(dparray array1, dparray array2):
148153
return result
149154

150155

156+
cpdef dparray dpnp_kron(dparray in_array1, dparray in_array2):
157+
cdef size_t ndim = max(in_array1.ndim, in_array2.ndim)
158+
159+
cdef dparray_shape_type in_array1_shape
160+
if in_array1.ndim < ndim:
161+
for i in range(ndim - in_array1.ndim):
162+
in_array1_shape.push_back(1)
163+
for i in range(in_array1.ndim):
164+
in_array1_shape.push_back(in_array1.shape[i])
165+
166+
cdef dparray_shape_type in_array2_shape
167+
if in_array2.ndim < ndim:
168+
for i in range(ndim - in_array2.ndim):
169+
in_array2_shape.push_back(1)
170+
for i in range(in_array2.ndim):
171+
in_array2_shape.push_back(in_array2.shape[i])
172+
173+
cdef dparray_shape_type result_shape
174+
for i in range(ndim):
175+
result_shape.push_back(in_array1_shape[i] * in_array2_shape[i])
176+
177+
# convert string type names (dparray.dtype) to C enum DPNPFuncType
178+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
179+
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(in_array2.dtype)
180+
181+
# get the FPTR data structure
182+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_KRON, param1_type, param2_type)
183+
184+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
185+
# ceate result array with type given by FPTR data
186+
cdef dparray result = dparray(result_shape, dtype=result_type)
187+
188+
cdef fptr_2in_1out_shapes_t func = <fptr_2in_1out_shapes_t > kernel_data.ptr
189+
# call FPTR function
190+
func(in_array1.get_data(), in_array2.get_data(), result.get_data(), < size_t * > in_array1_shape.data(), < size_t * > in_array2_shape.data(), < size_t * > result_shape.data(), ndim)
191+
192+
return result
193+
194+
151195
cpdef dparray dpnp_outer(dparray array1, dparray array2):
152196
cdef dparray_shape_type result_shape = (array1.size, array2.size)
153197
result_type = numpy.promote_types(array1.dtype, array1.dtype)

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,21 +194,30 @@ def inner(x1, x2, **kwargs):
194194
return call_origin(numpy.inner, x1, x2, **kwargs)
195195

196196

197-
def kron(x1, x2, **kwargs):
197+
def kron(a, b):
198198
"""
199199
Returns the kronecker product of two arrays.
200200
201201
For full documentation refer to :obj:`numpy.kron`.
202202
203-
Limitations
204-
-----------
205-
Function is executed sequentially on CPU.
206-
207203
.. seealso:: :obj:`dpnp.outer` returns the outer product of two arrays.
208204
209205
"""
210206

211-
return call_origin(numpy.kron, x1, x2, **kwargs)
207+
if not use_origin_backend(a):
208+
if dpnp.isscalar(a):
209+
a = dpnp.array(a)
210+
if dpnp.isscalar(b):
211+
b = dpnp.array(b)
212+
213+
if not isinstance(a, dparray):
214+
pass
215+
elif not isinstance(b, dparray):
216+
pass
217+
else:
218+
return dpnp_kron(a, b)
219+
220+
return call_origin(numpy.kron, a, b)
212221

213222

214223
def outer(x1, x2, **kwargs):

tests/skipped_tests.tbl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -835,13 +835,11 @@ tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matri
835835
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_dot_vec2
836836
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_multidim_vdot
837837
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_inner
838-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_kron
839838
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_outer
840839
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_vdot
841840
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
842841
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out_f_contiguous
843842
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_multidim_vdot
844-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_zerodim_kron
845843
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_array_scalar
846844
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_finite
847845
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_infinite

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,13 +1034,11 @@ tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matri
10341034
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_dot_vec2
10351035
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_multidim_vdot
10361036
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_inner
1037-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_kron
10381037
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_outer
10391038
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_vdot
10401039
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
10411040
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out_f_contiguous
10421041
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_multidim_vdot
1043-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_zerodim_kron
10441042
tests/third_party/cupy/linalg_tests/test_product.py::TestProductZeroLength_param_0_{params=((0, 0), 2)}::test_tensordot_zero_length
10451043
tests/third_party/cupy/linalg_tests/test_product.py::TestProductZeroLength_param_1_{params=((0, 0), (1, 0))}::test_tensordot_zero_length
10461044
tests/third_party/cupy/linalg_tests/test_product.py::TestProductZeroLength_param_2_{params=((0, 0, 0), 2)}::test_tensordot_zero_length

0 commit comments

Comments
 (0)