Skip to content

Commit 3e251c8

Browse files
authored
update dpnp.kron implementation (#1732)
* update dpnp.kron * address comments
1 parent 6c2036f commit 3e251c8

File tree

13 files changed

+226
-213
lines changed

13 files changed

+226
-213
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,6 @@ enum class DPNPFuncName : size_t
174174
DPNP_FN_INV, /**< Used in numpy.linalg.inv() impl */
175175
DPNP_FN_INVERT, /**< Used in numpy.invert() impl */
176176
DPNP_FN_KRON, /**< Used in numpy.kron() impl */
177-
DPNP_FN_KRON_EXT, /**< Used in numpy.kron() impl, requires extra parameters
178-
*/
179177
DPNP_FN_LEFT_SHIFT, /**< Used in numpy.left_shift() impl */
180178
DPNP_FN_LOG, /**< Used in numpy.log() impl */
181179
DPNP_FN_LOG10, /**< Used in numpy.log10() impl */

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -499,18 +499,6 @@ void (*dpnp_kron_default_c)(void *,
499499
size_t) =
500500
dpnp_kron_c<_DataType1, _DataType2, _ResultType>;
501501

502-
template <typename _DataType1, typename _DataType2, typename _ResultType>
503-
DPCTLSyclEventRef (*dpnp_kron_ext_c)(DPCTLSyclQueueRef,
504-
void *,
505-
void *,
506-
void *,
507-
shape_elem_type *,
508-
shape_elem_type *,
509-
shape_elem_type *,
510-
size_t,
511-
const DPCTLEventVectorRef) =
512-
dpnp_kron_c<_DataType1, _DataType2, _ResultType>;
513-
514502
template <typename _DataType>
515503
DPCTLSyclEventRef
516504
dpnp_matrix_rank_c(DPCTLSyclQueueRef q_ref,
@@ -890,67 +878,6 @@ void func_map_init_linalg_func(func_map_t &fmap)
890878
(void *)dpnp_kron_default_c<std::complex<double>, std::complex<double>,
891879
std::complex<double>>};
892880

893-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_INT] = {
894-
eft_INT, (void *)dpnp_kron_ext_c<int32_t, int32_t, int32_t>};
895-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_LNG] = {
896-
eft_LNG, (void *)dpnp_kron_ext_c<int32_t, int64_t, int64_t>};
897-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_FLT] = {
898-
eft_FLT, (void *)dpnp_kron_ext_c<int32_t, float, float>};
899-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_DBL] = {
900-
eft_DBL, (void *)dpnp_kron_ext_c<int32_t, double, double>};
901-
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_C128] = {
902-
// eft_C128, (void*)dpnp_kron_ext_c<int32_t, std::complex<double>,
903-
// std::complex<double>>};
904-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_INT] = {
905-
eft_LNG, (void *)dpnp_kron_ext_c<int64_t, int32_t, int64_t>};
906-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_LNG] = {
907-
eft_LNG, (void *)dpnp_kron_ext_c<int64_t, int64_t, int64_t>};
908-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_FLT] = {
909-
eft_FLT, (void *)dpnp_kron_ext_c<int64_t, float, float>};
910-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_DBL] = {
911-
eft_DBL, (void *)dpnp_kron_ext_c<int64_t, double, double>};
912-
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_C128] = {
913-
// eft_C128, (void*)dpnp_kron_ext_c<int64_t, std::complex<double>,
914-
// std::complex<double>>};
915-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_INT] = {
916-
eft_FLT, (void *)dpnp_kron_ext_c<float, int32_t, float>};
917-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_LNG] = {
918-
eft_FLT, (void *)dpnp_kron_ext_c<float, int64_t, float>};
919-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_FLT] = {
920-
eft_FLT, (void *)dpnp_kron_ext_c<float, float, float>};
921-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_DBL] = {
922-
eft_DBL, (void *)dpnp_kron_ext_c<float, double, double>};
923-
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_C128] = {
924-
// eft_C128, (void*)dpnp_kron_ext_c<float, std::complex<double>,
925-
// std::complex<double>>};
926-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_INT] = {
927-
eft_DBL, (void *)dpnp_kron_ext_c<double, int32_t, double>};
928-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_LNG] = {
929-
eft_DBL, (void *)dpnp_kron_ext_c<double, int64_t, double>};
930-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_FLT] = {
931-
eft_DBL, (void *)dpnp_kron_ext_c<double, float, double>};
932-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_DBL] = {
933-
eft_DBL, (void *)dpnp_kron_ext_c<double, double, double>};
934-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_C128] = {
935-
eft_C128, (void *)dpnp_kron_ext_c<double, std::complex<double>,
936-
std::complex<double>>};
937-
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_INT] = {
938-
// eft_C128, (void*)dpnp_kron_ext_c<std::complex<double>, int32_t,
939-
// std::complex<double>>};
940-
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_LNG] = {
941-
// eft_C128, (void*)dpnp_kron_ext_c<std::complex<double>, int64_t,
942-
// std::complex<double>>};
943-
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_FLT] = {
944-
// eft_C128, (void*)dpnp_kron_ext_c<std::complex<double>, float,
945-
// std::complex<double>>};
946-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_DBL] = {
947-
eft_C128, (void *)dpnp_kron_ext_c<std::complex<double>, double,
948-
std::complex<double>>};
949-
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_C128] = {
950-
eft_C128,
951-
(void *)dpnp_kron_ext_c<std::complex<double>, std::complex<double>,
952-
std::complex<double>>};
953-
954881
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_INT][eft_INT] = {
955882
eft_INT, (void *)dpnp_matrix_rank_default_c<int32_t>};
956883
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_LNG][eft_LNG] = {

dpnp/dpnp_algo/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11

22
set(dpnp_algo_pyx_deps
3-
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_linearalgebra.pxi
43
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_statistics.pxi
54
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_trigonometric.pxi
65
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_sorting.pxi

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
7474
DPNP_FN_FMOD_EXT
7575
DPNP_FN_FULL
7676
DPNP_FN_FULL_LIKE
77-
DPNP_FN_KRON
78-
DPNP_FN_KRON_EXT
7977
DPNP_FN_MAXIMUM
8078
DPNP_FN_MAXIMUM_EXT
8179
DPNP_FN_MEDIAN

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ __all__ = [
6060

6161
include "dpnp_algo_arraycreation.pxi"
6262
include "dpnp_algo_indexing.pxi"
63-
include "dpnp_algo_linearalgebra.pxi"
6463
include "dpnp_algo_logic.pxi"
6564
include "dpnp_algo_mathematical.pxi"
6665
include "dpnp_algo_sorting.pxi"

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi

Lines changed: 0 additions & 106 deletions
This file was deleted.

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,12 @@
4444
import dpnp
4545

4646
# pylint: disable=no-name-in-module
47-
from .dpnp_algo import (
48-
dpnp_kron,
49-
)
5047
from .dpnp_utils import (
5148
call_origin,
5249
)
5350
from .dpnp_utils.dpnp_utils_linearalgebra import (
5451
dpnp_dot,
52+
dpnp_kron,
5553
dpnp_matmul,
5654
)
5755

@@ -305,22 +303,72 @@ def inner(a, b):
305303
return dpnp.tensordot(a, b, axes=(-1, -1))
306304

307305

308-
def kron(x1, x2):
306+
def kron(a, b):
309307
"""
310308
Returns the kronecker product of two arrays.
311309
312310
For full documentation refer to :obj:`numpy.kron`.
313311
314-
.. seealso:: :obj:`dpnp.outer` returns the outer product of two arrays.
312+
Parameters
313+
----------
314+
a : {dpnp.ndarray, usm_ndarray, scalar}
315+
First input array. Both inputs `a` and `b` can not be scalars
316+
at the same time.
317+
b : {dpnp.ndarray, usm_ndarray, scalar}
318+
Second input array. Both inputs `a` and `b` can not be scalars
319+
at the same time.
320+
321+
Returns
322+
-------
323+
out : dpnp.ndarray
324+
Returns the Kronecker product.
325+
326+
See Also
327+
--------
328+
:obj:`dpnp.outer` : Returns the outer product of two arrays.
329+
330+
Examples
331+
--------
332+
>>> import dpnp as np
333+
>>> a = np.array([1, 10, 100])
334+
>>> b = np.array([5, 6, 7])
335+
>>> np.kron(a, b)
336+
array([ 5, 6, 7, ..., 500, 600, 700])
337+
>>> np.kron(b, a)
338+
array([ 5, 50, 500, ..., 7, 70, 700])
339+
340+
>>> np.kron(np.eye(2), np.ones((2,2)))
341+
array([[1., 1., 0., 0.],
342+
[1., 1., 0., 0.],
343+
[0., 0., 1., 1.],
344+
[0., 0., 1., 1.]])
345+
346+
>>> a = np.arange(100).reshape((2,5,2,5))
347+
>>> b = np.arange(24).reshape((2,3,4))
348+
>>> c = np.kron(a,b)
349+
>>> c.shape
350+
(2, 10, 6, 20)
351+
>>> I = (1,3,0,2)
352+
>>> J = (0,2,1)
353+
>>> J1 = (0,) + J # extend to ndim=4
354+
>>> S1 = (1,) + b.shape
355+
>>> K = tuple(np.array(I) * np.array(S1) + np.array(J1))
356+
>>> c[K] == a[I]*b[J]
357+
array(True)
315358
316359
"""
317360

318-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
319-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
320-
if x1_desc and x2_desc:
321-
return dpnp_kron(x1_desc, x2_desc).get_pyobj()
361+
dpnp.check_supported_arrays_type(a, b, scalar_type=True)
362+
363+
if dpnp.isscalar(a) or dpnp.isscalar(b):
364+
return dpnp.multiply(a, b)
365+
366+
a_ndim = a.ndim
367+
b_ndim = b.ndim
368+
if a_ndim == 0 or b_ndim == 0:
369+
return dpnp.multiply(a, b)
322370

323-
return call_origin(numpy.kron, x1, x2)
371+
return dpnp_kron(a, b, a_ndim, b_ndim)
324372

325373

326374
def matmul(

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from dpnp.dpnp_array import dpnp_array
3636
from dpnp.dpnp_utils import get_usm_allocations
3737

38-
__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_matmul"]
38+
__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_kron", "dpnp_matmul"]
3939

4040

4141
def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
@@ -476,6 +476,34 @@ def dpnp_cross(a, b, cp, exec_q):
476476
return cp
477477

478478

479+
def dpnp_kron(a, b, a_ndim, b_ndim):
480+
"""Returns the kronecker product of two arrays."""
481+
482+
a_shape = a.shape
483+
b_shape = b.shape
484+
if not a.flags.contiguous:
485+
a = dpnp.reshape(a, a_shape)
486+
if not b.flags.contiguous:
487+
b = dpnp.reshape(b, b_shape)
488+
489+
# Equalise the shapes by prepending smaller one with 1s
490+
a_shape = (1,) * max(0, b_ndim - a_ndim) + a_shape
491+
b_shape = (1,) * max(0, a_ndim - b_ndim) + b_shape
492+
493+
# Insert empty dimensions
494+
a_arr = dpnp.expand_dims(a, axis=tuple(range(b_ndim - a_ndim)))
495+
b_arr = dpnp.expand_dims(b, axis=tuple(range(a_ndim - b_ndim)))
496+
497+
# Compute the product
498+
ndim = max(b_ndim, a_ndim)
499+
a_arr = dpnp.expand_dims(a_arr, axis=tuple(range(1, 2 * ndim, 2)))
500+
b_arr = dpnp.expand_dims(b_arr, axis=tuple(range(0, 2 * ndim, 2)))
501+
result = dpnp.multiply(a_arr, b_arr)
502+
503+
# Reshape back
504+
return result.reshape(tuple(numpy.multiply(a_shape, b_shape)))
505+
506+
479507
def dpnp_dot(a, b, /, out=None, *, conjugate=False):
480508
"""
481509
Return the dot product of two arrays.

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def _multi_dot(arrays, order, i, j, out=None):
401401

402402
def _multi_dot_matrix_chain_order(n, arrays, return_costs=False):
403403
"""
404-
Return a dpnp.ndarray that encodes the optimal order of mutiplications.
404+
Return a dpnp.ndarray that encodes the optimal order of multiplications.
405405
406406
The optimal order array is then used by `_multi_dot()` to do the
407407
multiplication.

0 commit comments

Comments
 (0)