@@ -31,8 +31,8 @@ def _rand1_shape(shape, prob):
3131def augment_einsum_testcases (* params ):
3232 """Modify shapes in einsum tests
3333
34- Shape parameter should be starts with " shape_" .
35- The original parameter is stored as " _raw_params" .
34+ Shape parameter should be starts with ' shape_' .
35+ The original parameter is stored as ' _raw_params' .
3636
3737 Args:
3838 params (sequence of dicts)
@@ -61,6 +61,7 @@ def augment_einsum_testcases(*params):
6161
6262
6363class TestEinSumError :
64+
6465 def test_irregular_ellipsis1 (self ):
6566 for xp in (numpy , cupy ):
6667 with pytest .raises (ValueError ):
@@ -233,6 +234,7 @@ def test_invalid_arrow4(self):
233234
234235
235236class TestListArgEinSumError :
237+
236238 @testing .with_requires ("numpy>=1.19" )
237239 def test_invalid_sub1 (self ):
238240 for xp in (numpy , cupy ):
@@ -338,6 +340,7 @@ def test_numpy_15961_list(self, xp, do_opt):
338340 )
339341)
340342class TestEinSumUnaryOperation :
343+
341344 @testing .for_all_dtypes (no_bool = False )
342345 @testing .numpy_cupy_allclose (
343346 rtol = {numpy .float16 : 1e-1 , "default" : 1e-7 }, contiguous_check = False
@@ -350,13 +353,15 @@ def test_einsum_unary(self, xp, dtype):
350353 testing .assert_allclose (optimized_out , out )
351354 return out
352355
353- @pytest .mark .skip ("view is not supported" )
354356 @testing .for_all_dtypes (no_bool = False )
355357 @testing .numpy_cupy_equal ()
356358 def test_einsum_unary_views (self , xp , dtype ):
357359 a = testing .shaped_arange (self .shape_a , xp , dtype )
358360 b = xp .einsum (self .subscripts , a )
359-
361+ if xp is cupy :
362+ return (
363+ b .ndim == 0 or b .get_array ()._pointer == a .get_array ()._pointer
364+ )
360365 return b .ndim == 0 or b .base is a
361366
362367 @testing .for_all_dtypes_combination (
@@ -373,13 +378,13 @@ def test_einsum_unary_dtype(self, xp, dtype_a, dtype_out):
373378
374379
375380class TestEinSumUnaryOperationWithScalar :
376- @pytest .mark .skip ("All operands are scalar. " )
381+ @pytest .mark .skip ("Scalar input is not supported " )
377382 @testing .for_all_dtypes ()
378383 @testing .numpy_cupy_allclose ()
379384 def test_scalar_int (self , xp , dtype ):
380385 return xp .asarray (xp .einsum ("->" , 2 , dtype = dtype ))
381386
382- @pytest .mark .skip ("All operands are scalar. " )
387+ @pytest .mark .skip ("Scalar input is not supported " )
383388 @testing .for_all_dtypes ()
384389 @testing .numpy_cupy_allclose ()
385390 def test_scalar_float (self , xp , dtype ):
@@ -574,7 +579,7 @@ def test_einsum_ternary(self, xp, dtype_a, dtype_b, dtype_c):
574579
575580 if xp is not numpy : # Avoid numpy issues #11059, #11060
576581 for optimize in [
577- True , # " greedy"
582+ True , # ' greedy'
578583 "optimal" ,
579584 ["einsum_path" , (0 , 1 ), (0 , 1 )],
580585 ["einsum_path" , (0 , 2 ), (0 , 1 )],
@@ -616,6 +621,7 @@ def test_einsum_ternary(self, xp, dtype_a, dtype_b, dtype_c):
616621 )
617622)
618623class TestEinSumLarge :
624+
619625 chars = "abcdefghij"
620626 sizes = (2 , 3 , 4 , 5 , 4 , 3 , 2 , 6 , 5 , 4 , 3 )
621627 size_dict = {}
@@ -638,7 +644,7 @@ def test_einsum(self, xp, shapes):
638644 ]
639645 # TODO(kataoka): support memory efficient cupy.einsum
640646 with warnings .catch_warnings (record = True ) as ws :
641- # I hope there" s no problem with np.einsum for these cases...
647+ # I hope there' s no problem with np.einsum for these cases...
642648 out = xp .einsum (self .subscript , * arrays , optimize = self .opt )
643649 if xp is not numpy and isinstance (
644650 self .opt , tuple
0 commit comments