Skip to content

Commit e3ac240

Browse files
committed
Update linalg_tests/test_einsum.py
1 parent eb12f2d commit e3ac240

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

dpnp/tests/third_party/cupy/linalg_tests/test_einsum.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def _rand1_shape(shape, prob):
3131
def augment_einsum_testcases(*params):
3232
"""Modify shapes in einsum tests
3333
34-
Shape parameter should be starts with "shape_".
35-
The original parameter is stored as "_raw_params".
34+
Shape parameter should be starts with 'shape_'.
35+
The original parameter is stored as '_raw_params'.
3636
3737
Args:
3838
params (sequence of dicts)
@@ -61,6 +61,7 @@ def augment_einsum_testcases(*params):
6161

6262

6363
class TestEinSumError:
64+
6465
def test_irregular_ellipsis1(self):
6566
for xp in (numpy, cupy):
6667
with pytest.raises(ValueError):
@@ -233,6 +234,7 @@ def test_invalid_arrow4(self):
233234

234235

235236
class TestListArgEinSumError:
237+
236238
@testing.with_requires("numpy>=1.19")
237239
def test_invalid_sub1(self):
238240
for xp in (numpy, cupy):
@@ -338,6 +340,7 @@ def test_numpy_15961_list(self, xp, do_opt):
338340
)
339341
)
340342
class TestEinSumUnaryOperation:
343+
341344
@testing.for_all_dtypes(no_bool=False)
342345
@testing.numpy_cupy_allclose(
343346
rtol={numpy.float16: 1e-1, "default": 1e-7}, contiguous_check=False
@@ -350,13 +353,15 @@ def test_einsum_unary(self, xp, dtype):
350353
testing.assert_allclose(optimized_out, out)
351354
return out
352355

353-
@pytest.mark.skip("view is not supported")
354356
@testing.for_all_dtypes(no_bool=False)
355357
@testing.numpy_cupy_equal()
356358
def test_einsum_unary_views(self, xp, dtype):
357359
a = testing.shaped_arange(self.shape_a, xp, dtype)
358360
b = xp.einsum(self.subscripts, a)
359-
361+
if xp is cupy:
362+
return (
363+
b.ndim == 0 or b.get_array()._pointer == a.get_array()._pointer
364+
)
360365
return b.ndim == 0 or b.base is a
361366

362367
@testing.for_all_dtypes_combination(
@@ -373,13 +378,13 @@ def test_einsum_unary_dtype(self, xp, dtype_a, dtype_out):
373378

374379

375380
class TestEinSumUnaryOperationWithScalar:
376-
@pytest.mark.skip("All operands are scalar.")
381+
@pytest.mark.skip("Scalar input is not supported")
377382
@testing.for_all_dtypes()
378383
@testing.numpy_cupy_allclose()
379384
def test_scalar_int(self, xp, dtype):
380385
return xp.asarray(xp.einsum("->", 2, dtype=dtype))
381386

382-
@pytest.mark.skip("All operands are scalar.")
387+
@pytest.mark.skip("Scalar input is not supported")
383388
@testing.for_all_dtypes()
384389
@testing.numpy_cupy_allclose()
385390
def test_scalar_float(self, xp, dtype):
@@ -574,7 +579,7 @@ def test_einsum_ternary(self, xp, dtype_a, dtype_b, dtype_c):
574579

575580
if xp is not numpy: # Avoid numpy issues #11059, #11060
576581
for optimize in [
577-
True, # "greedy"
582+
True, # 'greedy'
578583
"optimal",
579584
["einsum_path", (0, 1), (0, 1)],
580585
["einsum_path", (0, 2), (0, 1)],
@@ -616,6 +621,7 @@ def test_einsum_ternary(self, xp, dtype_a, dtype_b, dtype_c):
616621
)
617622
)
618623
class TestEinSumLarge:
624+
619625
chars = "abcdefghij"
620626
sizes = (2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3)
621627
size_dict = {}
@@ -638,7 +644,7 @@ def test_einsum(self, xp, shapes):
638644
]
639645
# TODO(kataoka): support memory efficient cupy.einsum
640646
with warnings.catch_warnings(record=True) as ws:
641-
# I hope there"s no problem with np.einsum for these cases...
647+
# I hope there's no problem with np.einsum for these cases...
642648
out = xp.einsum(self.subscript, *arrays, optimize=self.opt)
643649
if xp is not numpy and isinstance(
644650
self.opt, tuple

0 commit comments

Comments
 (0)