Skip to content

Commit 75c9b9a

Browse files
committed
Raise ValueError exception per axes keyword
1 parent d9931cb commit 75c9b9a

File tree

3 files changed

+55
-54
lines changed

3 files changed

+55
-54
lines changed

dpnp/dpnp_array.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,7 @@ def __imatmul__(self, other):
372372
if the result without `out` would have less dimensions than `a`.
373373
Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the
374374
case exactly when the second operand has both core dimensions.
375-
376-
The error here will be confusing, but for now, we enforce this by
377-
passing the correct `axes=`.
375+
We have to enforce this check by passing the correct `axes=`.
378376
"""
379377
if self.ndim == 1:
380378
axes = [(-1,), (-2, -1), (-1,)]

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def _validate_internal(axes, i, ndim):
563563

564564
if x1_ndim == 1 and x2_ndim == 1:
565565
if axes[2] != ():
566-
raise TypeError("Axes item 2 should be an empty tuple.")
566+
raise ValueError("Axes item 2 should be an empty tuple.")
567567
elif x1_ndim == 1 or x2_ndim == 1:
568568
axes[2] = _validate_internal(axes[2], 2, 1)
569569
else:

tests/test_mathematical.py

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4228,24 +4228,25 @@ def test_shapes(self, a_sh, b_sh):
42284228

42294229

42304230
class TestMatmulInvalidCases:
4231+
@pytest.mark.parametrize("xp", [numpy, dpnp])
42314232
@pytest.mark.parametrize(
4232-
"shape_pair",
4233+
"shape1, shape2",
42334234
[
42344235
((3, 2), ()),
42354236
((), (3, 2)),
42364237
((), ()),
42374238
],
42384239
)
4239-
def test_zero_dim(self, shape_pair):
4240-
for xp in (numpy, dpnp):
4241-
shape1, shape2 = shape_pair
4242-
x1 = xp.arange(numpy.prod(shape1), dtype=xp.float32).reshape(shape1)
4243-
x2 = xp.arange(numpy.prod(shape2), dtype=xp.float32).reshape(shape2)
4244-
with pytest.raises(ValueError):
4245-
xp.matmul(x1, x2)
4240+
def test_zero_dim(self, xp, shape1, shape2):
4241+
x1 = xp.arange(numpy.prod(shape1), dtype=xp.float32).reshape(shape1)
4242+
x2 = xp.arange(numpy.prod(shape2), dtype=xp.float32).reshape(shape2)
42464243

4244+
with pytest.raises(ValueError):
4245+
xp.matmul(x1, x2)
4246+
4247+
@pytest.mark.parametrize("xp", [numpy, dpnp])
42474248
@pytest.mark.parametrize(
4248-
"shape_pair",
4249+
"shape1, shape2",
42494250
[
42504251
((3,), (4,)),
42514252
((2, 3), (4, 5)),
@@ -4258,16 +4259,16 @@ def test_zero_dim(self, shape_pair):
42584259
((6, 5, 3, 2), (3, 2, 4)),
42594260
],
42604261
)
4261-
def test_invalid_shape(self, shape_pair):
4262-
for xp in (numpy, dpnp):
4263-
shape1, shape2 = shape_pair
4264-
x1 = xp.arange(numpy.prod(shape1), dtype=xp.float32).reshape(shape1)
4265-
x2 = xp.arange(numpy.prod(shape2), dtype=xp.float32).reshape(shape2)
4266-
with pytest.raises(ValueError):
4267-
xp.matmul(x1, x2)
4262+
def test_invalid_shape(self, xp, shape1, shape2):
4263+
x1 = xp.arange(numpy.prod(shape1), dtype=xp.float32).reshape(shape1)
4264+
x2 = xp.arange(numpy.prod(shape2), dtype=xp.float32).reshape(shape2)
42684265

4266+
with pytest.raises(ValueError):
4267+
xp.matmul(x1, x2)
4268+
4269+
@pytest.mark.parametrize("xp", [numpy, dpnp])
42694270
@pytest.mark.parametrize(
4270-
"shape_pair",
4271+
"shape1, shape2, out_shape",
42714272
[
42724273
((5, 4, 3), (3, 1), (3, 4, 1)),
42734274
((5, 4, 3), (3, 1), (5, 6, 1)),
@@ -4279,24 +4280,24 @@ def test_invalid_shape(self, shape_pair):
42794280
((4,), (3, 4, 5), (3, 6)),
42804281
],
42814282
)
4282-
def test_invalid_shape_out(self, shape_pair):
4283-
for xp in (numpy, dpnp):
4284-
shape1, shape2, out_shape = shape_pair
4285-
x1 = xp.arange(numpy.prod(shape1), dtype=xp.float32).reshape(shape1)
4286-
x2 = xp.arange(numpy.prod(shape2), dtype=xp.float32).reshape(shape2)
4287-
res = xp.empty(out_shape)
4288-
with pytest.raises(ValueError):
4289-
xp.matmul(x1, x2, out=res)
4283+
def test_invalid_shape_out(self, xp, shape1, shape2, out_shape):
4284+
x1 = xp.arange(numpy.prod(shape1), dtype=xp.float32).reshape(shape1)
4285+
x2 = xp.arange(numpy.prod(shape2), dtype=xp.float32).reshape(shape2)
4286+
res = xp.empty(out_shape)
42904287

4288+
with pytest.raises(ValueError):
4289+
xp.matmul(x1, x2, out=res)
4290+
4291+
@pytest.mark.parametrize("xp", [numpy, dpnp])
42914292
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)[:-2])
4292-
def test_invalid_dtype(self, dtype):
4293+
def test_invalid_dtype(self, xp, dtype):
42934294
dpnp_dtype = get_all_dtypes(no_none=True)[-1]
4294-
a1 = dpnp.arange(5 * 4, dtype=dpnp_dtype).reshape(5, 4)
4295-
a2 = dpnp.arange(7 * 4, dtype=dpnp_dtype).reshape(4, 7)
4296-
dp_out = dpnp.empty((5, 7), dtype=dtype)
4295+
a1 = xp.arange(5 * 4, dtype=dpnp_dtype).reshape(5, 4)
4296+
a2 = xp.arange(7 * 4, dtype=dpnp_dtype).reshape(4, 7)
4297+
dp_out = xp.empty((5, 7), dtype=dtype)
42974298

42984299
with pytest.raises(TypeError):
4299-
dpnp.matmul(a1, a2, out=dp_out)
4300+
xp.matmul(a1, a2, out=dp_out)
43004301

43014302
def test_exe_q(self):
43024303
x1 = dpnp.ones((5, 4), sycl_queue=dpctl.SyclQueue())
@@ -4310,13 +4311,14 @@ def test_exe_q(self):
43104311
with pytest.raises(ExecutionPlacementError):
43114312
dpnp.matmul(x1, x2, out=out)
43124313

4313-
def test_matmul_casting(self):
4314-
a1 = dpnp.arange(2 * 4, dtype=dpnp.float32).reshape(2, 4)
4315-
a2 = dpnp.arange(4 * 3).reshape(4, 3)
4314+
@pytest.mark.parametrize("xp", [numpy, dpnp])
4315+
def test_matmul_casting(self, xp):
4316+
a1 = xp.arange(2 * 4, dtype=xp.float32).reshape(2, 4)
4317+
a2 = xp.arange(4 * 3).reshape(4, 3)
43164318

4317-
res = dpnp.empty((2, 3), dtype=dpnp.int64)
4319+
res = xp.empty((2, 3), dtype=xp.int64)
43184320
with pytest.raises(TypeError):
4319-
dpnp.matmul(a1, a2, out=res, casting="safe")
4321+
xp.matmul(a1, a2, out=res, casting="safe")
43204322

43214323
def test_matmul_not_implemented(self):
43224324
a1 = dpnp.arange(2 * 4).reshape(2, 4)
@@ -4332,52 +4334,53 @@ def test_matmul_not_implemented(self):
43324334
with pytest.raises(NotImplementedError):
43334335
dpnp.matmul(a1, a2, axis=2)
43344336

4335-
def test_matmul_axes(self):
4336-
a1 = dpnp.arange(120).reshape(2, 5, 3, 4)
4337-
a2 = dpnp.arange(120).reshape(4, 2, 5, 3)
4337+
@pytest.mark.parametrize("xp", [numpy, dpnp])
4338+
def test_matmul_axes(self, xp):
4339+
a1 = xp.arange(120).reshape(2, 5, 3, 4)
4340+
a2 = xp.arange(120).reshape(4, 2, 5, 3)
43384341

43394342
# axes must be a list
43404343
axes = ((3, 1), (2, 0), (0, 1))
43414344
with pytest.raises(TypeError):
4342-
dpnp.matmul(a1, a2, axes=axes)
4345+
xp.matmul(a1, a2, axes=axes)
43434346

43444347
# axes must be be a list of three tuples
43454348
axes = [(3, 1), (2, 0)]
43464349
with pytest.raises(ValueError):
4347-
dpnp.matmul(a1, a2, axes=axes)
4350+
xp.matmul(a1, a2, axes=axes)
43484351

43494352
# axes item should be a tuple
43504353
axes = [(3, 1), (2, 0), [0, 1]]
43514354
with pytest.raises(TypeError):
4352-
dpnp.matmul(a1, a2, axes=axes)
4355+
xp.matmul(a1, a2, axes=axes)
43534356

43544357
# axes item should be a tuple with 2 elements
43554358
axes = [(3, 1), (2, 0), (0, 1, 2)]
43564359
with pytest.raises(ValueError):
4357-
dpnp.matmul(a1, a2, axes=axes)
4360+
xp.matmul(a1, a2, axes=axes)
43584361

43594362
# axes must be an integer
43604363
axes = [(3, 1), (2, 0), (0.0, 1)]
43614364
with pytest.raises(TypeError):
4362-
dpnp.matmul(a1, a2, axes=axes)
4365+
xp.matmul(a1, a2, axes=axes)
43634366

43644367
# axes item 2 should be an empty tuple
4365-
a = dpnp.arange(3)
4368+
a = xp.arange(3)
43664369
axes = [0, 0, 0]
4367-
with pytest.raises(TypeError):
4368-
dpnp.matmul(a, a, axes=axes)
4370+
with pytest.raises(ValueError):
4371+
xp.matmul(a, a, axes=axes)
43694372

4370-
a = dpnp.arange(3 * 4 * 5).reshape(3, 4, 5)
4371-
b = dpnp.arange(3)
4373+
a = xp.arange(3 * 4 * 5).reshape(3, 4, 5)
4374+
b = xp.arange(3)
43724375
# list object cannot be interpreted as an integer
43734376
axes = [(1, 0), (0), [0]]
43744377
with pytest.raises(TypeError):
4375-
dpnp.matmul(a, b, axes=axes)
4378+
xp.matmul(a, b, axes=axes)
43764379

43774380
# axes item should be a tuple with a single element, or an integer
43784381
axes = [(1, 0), (0), (0, 1)]
43794382
with pytest.raises(ValueError):
4380-
dpnp.matmul(a, b, axes=axes)
4383+
xp.matmul(a, b, axes=axes)
43814384

43824385

43834386
def test_elemenwise_nin_nout():

0 commit comments

Comments
 (0)