Skip to content

Commit d9931cb

Browse files
committed
Add support for inplace matrix multiplication
1 parent cb801da commit d9931cb

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

dpnp/dpnp_array.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def __bool__(self):
205205
return self._array_obj.__bool__()
206206

207207
# '__class__',
208+
# `__class_getitem__`,
208209

209210
def __complex__(self):
210211
return self._array_obj.__complex__()
@@ -335,6 +336,8 @@ def __getitem__(self, key):
335336
res._array_obj = item
336337
return res
337338

339+
# '__getstate__',
340+
338341
def __gt__(self, other):
339342
"""Return ``self>value``."""
340343
return dpnp.greater(self, other)
@@ -361,7 +364,25 @@ def __ilshift__(self, other):
361364
dpnp.left_shift(self, other, out=self)
362365
return self
363366

364-
# '__imatmul__',
367+
def __imatmul__(self, other):
368+
"""Return ``self@=value``."""
369+
370+
"""
371+
Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast
372+
if the result without `out` would have less dimensions than `a`.
373+
Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the
374+
case exactly when the second operand has both core dimensions.
375+
376+
The error here will be confusing, but for now, we enforce this by
377+
passing the correct `axes=`.
378+
"""
379+
if self.ndim == 1:
380+
axes = [(-1,), (-2, -1), (-1,)]
381+
else:
382+
axes = [(-2, -1), (-2, -1), (-2, -1)]
383+
384+
dpnp.matmul(self, other, out=self, axes=axes)
385+
return self
365386

366387
def __imod__(self, other):
367388
"""Return ``self%=value``."""

tests/test_mathematical.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4168,6 +4168,65 @@ def test_matmul_with_offsets(self, sh1, sh2):
41684168
assert_array_equal(result, expected)
41694169

41704170

4171+
class TestMatmulInplace:
4172+
ALL_DTYPES = get_all_dtypes(no_none=True)
4173+
DTYPES = {}
4174+
for i in ALL_DTYPES:
4175+
for j in ALL_DTYPES:
4176+
if numpy.can_cast(j, i):
4177+
DTYPES[f"{i}-{j}"] = (i, j)
4178+
4179+
@pytest.mark.parametrize("dtype1, dtype2", DTYPES.values())
4180+
def test_basic(self, dtype1, dtype2):
4181+
a = numpy.arange(10).reshape(5, 2).astype(dtype1)
4182+
b = numpy.ones((2, 2), dtype=dtype2)
4183+
ia, ib = dpnp.array(a), dpnp.array(b)
4184+
ia_id = id(ia)
4185+
4186+
a @= b
4187+
ia @= ib
4188+
assert id(ia) == ia_id
4189+
assert_dtype_allclose(ia, a)
4190+
4191+
@pytest.mark.parametrize(
4192+
"a_sh, b_sh",
4193+
[
4194+
pytest.param((10**5, 10), (10, 10), id="2d_large"),
4195+
pytest.param((10**4, 10, 10), (1, 10, 10), id="3d_large"),
4196+
pytest.param((3,), (3,), id="1d"),
4197+
pytest.param((3, 3), (3,), id="2d_1d"),
4198+
pytest.param((3,), (3, 3), id="1d_2d"),
4199+
pytest.param((3, 3), (3, 1), id="2d_broadcast"),
4200+
pytest.param((1, 3), (3, 3), id="2d_broadcast_reverse"),
4201+
pytest.param((3, 3, 3), (1, 3, 1), id="3d_broadcast1"),
4202+
pytest.param((3, 3, 3), (1, 3, 3), id="3d_broadcast2"),
4203+
pytest.param((3, 3, 3), (3, 3, 1), id="3d_broadcast3"),
4204+
pytest.param((1, 3, 3), (3, 3, 3), id="3d_broadcast_reverse1"),
4205+
pytest.param((3, 1, 3), (3, 3, 3), id="3d_broadcast_reverse2"),
4206+
pytest.param((1, 1, 3), (3, 3, 3), id="3d_broadcast_reverse3"),
4207+
],
4208+
)
4209+
def test_shapes(self, a_sh, b_sh):
4210+
a_sz, b_sz = numpy.prod(a_sh), numpy.prod(b_sh)
4211+
a = numpy.arange(a_sz).reshape(a_sh).astype(numpy.float64)
4212+
b = numpy.arange(b_sz).reshape(b_sh)
4213+
4214+
ia, ib = dpnp.array(a), dpnp.array(b)
4215+
ia_id = id(ia)
4216+
4217+
expected = a @ b
4218+
if expected.shape != a_sh:
4219+
with pytest.raises(ValueError):
4220+
a @= b
4221+
4222+
with pytest.raises(ValueError):
4223+
ia @= ib
4224+
else:
4225+
ia @= ib
4226+
assert id(ia) == ia_id
4227+
assert_dtype_allclose(ia, expected)
4228+
4229+
41714230
class TestMatmulInvalidCases:
41724231
@pytest.mark.parametrize(
41734232
"shape_pair",

0 commit comments

Comments
 (0)