Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,10 +676,6 @@ def dpnp_kron(a, b, a_ndim, b_ndim):

a_shape = a.shape
b_shape = b.shape
if not a.flags.contiguous:
a = dpnp.reshape(a, a_shape)
if not b.flags.contiguous:
b = dpnp.reshape(b, b_shape)

# Equalise the shapes by prepending smaller one with 1s
a_shape = (1,) * max(0, b_ndim - a_ndim) + a_shape
Expand All @@ -693,7 +689,7 @@ def dpnp_kron(a, b, a_ndim, b_ndim):
ndim = max(b_ndim, a_ndim)
a_arr = dpnp.expand_dims(a_arr, axis=tuple(range(1, 2 * ndim, 2)))
b_arr = dpnp.expand_dims(b_arr, axis=tuple(range(0, 2 * ndim, 2)))
result = dpnp.multiply(a_arr, b_arr, order="C")
result = dpnp.multiply(a_arr, b_arr)

# Reshape back
return result.reshape(tuple(numpy.multiply(a_shape, b_shape)))
Expand Down
28 changes: 27 additions & 1 deletion tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def test_kron_input_dtype_matrix(self, dtype1, dtype2):
@pytest.mark.parametrize(
"stride", [3, -1, -2, -4], ids=["3", "-1", "-2", "-4"]
)
def test_kron_strided(self, dtype, stride):
def test_kron_strided1(self, dtype, stride):
a = numpy.arange(20, dtype=dtype)
b = numpy.arange(20, dtype=dtype)
ia = dpnp.array(a)
Expand All @@ -751,6 +751,32 @@ def test_kron_strided(self, dtype, stride):
expected = numpy.kron(a[::stride], b[::stride])
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("stride", [2, -1, -2], ids=["2", "-1", "-2"])
def test_kron_strided2(self, stride):
a = numpy.arange(48).reshape(6, 8)
b = numpy.arange(480).reshape(6, 8, 10)
ia = dpnp.array(a)
ib = dpnp.array(b)

result = dpnp.kron(
ia[::stride, ::stride], ib[::stride, ::stride, ::stride]
)
expected = numpy.kron(
a[::stride, ::stride], b[::stride, ::stride, ::stride]
)
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("order", ["C", "F", "A"])
def test_kron_order(self, order):
a = numpy.arange(48).reshape(6, 8, order=order)
b = numpy.arange(480).reshape(6, 8, 10, order=order)
ia = dpnp.array(a)
ib = dpnp.array(b)

result = dpnp.kron(ia, ib)
expected = numpy.kron(a, b)
assert_dtype_allclose(result, expected)


class TestMultiDot:
def setup_method(self):
Expand Down
Loading