@@ -6360,8 +6360,8 @@ def f(x):
63606360 def test_intermediate_einsum (self , mesh ):
63616361 shape1 = (8 , 32 , 1 , 16 )
63626362 shape2 = (8 , 32 , 1 , 8 )
6363- np_inp1 = np .arange (math .prod (shape1 )).reshape (shape1 )
6364- np_inp2 = np .arange (math .prod (shape2 )).reshape (shape2 )
6363+ np_inp1 = np .ones (math .prod (shape1 )).reshape (shape1 )
6364+ np_inp2 = np .ones (math .prod (shape2 )).reshape (shape2 )
63656365
63666366 s = NamedSharding (mesh , P ('data' ))
63676367 arr1 = jax .device_put (np_inp1 , s )
@@ -6387,9 +6387,9 @@ def test_intermediate_einsum_auto_complete_spec(self, mesh):
63876387 shape1 = (8 , 32 , 2 * 16 )
63886388 shape2 = (8 , 32 , 2 , 8 )
63896389 shape3 = (8 , 32 , 2 , 8 )
6390- np_inp1 = np .arange (math .prod (shape1 )).reshape (shape1 )
6391- np_inp2 = np .arange (math .prod (shape2 )).reshape (shape2 )
6392- np_inp3 = np .arange (math .prod (shape3 )).reshape (shape3 )
6390+ np_inp1 = np .ones (math .prod (shape1 )).reshape (shape1 )
6391+ np_inp2 = np .ones (math .prod (shape2 )).reshape (shape2 )
6392+ np_inp3 = np .ones (math .prod (shape3 )).reshape (shape3 )
63936393
63946394 arr1 = jax .device_put (np_inp1 , s )
63956395 arr2 = jax .device_put (np_inp2 , s )
@@ -6436,8 +6436,8 @@ def f(condition, x, y):
64366436 def test_intermediate_einsum_conflict_error (self , mesh ):
64376437 shape1 = (8 , 32 , 1 , 16 )
64386438 shape2 = (8 , 32 , 1 , 8 )
6439- np_inp1 = np .arange (math .prod (shape1 )).reshape (shape1 )
6440- np_inp2 = np .arange (math .prod (shape2 )).reshape (shape2 )
6439+ np_inp1 = np .ones (math .prod (shape1 )).reshape (shape1 )
6440+ np_inp2 = np .ones (math .prod (shape2 )).reshape (shape2 )
64416441
64426442 arr1 = jax .device_put (
64436443 np_inp1 , NamedSharding (mesh , P (None , None , None , 'data' )))
0 commit comments