Skip to content

Commit 47dde87

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
Use np.ones to avoid signed integer overflow at run time
PiperOrigin-RevId: 738569856
1 parent f747112 commit 47dde87

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tests/pjit_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6360,8 +6360,8 @@ def f(x):
63606360
def test_intermediate_einsum(self, mesh):
63616361
shape1 = (8, 32, 1, 16)
63626362
shape2 = (8, 32, 1, 8)
6363-
np_inp1 = np.arange(math.prod(shape1)).reshape(shape1)
6364-
np_inp2 = np.arange(math.prod(shape2)).reshape(shape2)
6363+
np_inp1 = np.ones(math.prod(shape1)).reshape(shape1)
6364+
np_inp2 = np.ones(math.prod(shape2)).reshape(shape2)
63656365

63666366
s = NamedSharding(mesh, P('data'))
63676367
arr1 = jax.device_put(np_inp1, s)
@@ -6387,9 +6387,9 @@ def test_intermediate_einsum_auto_complete_spec(self, mesh):
63876387
shape1 = (8, 32, 2*16)
63886388
shape2 = (8, 32, 2, 8)
63896389
shape3 = (8, 32, 2, 8)
6390-
np_inp1 = np.arange(math.prod(shape1)).reshape(shape1)
6391-
np_inp2 = np.arange(math.prod(shape2)).reshape(shape2)
6392-
np_inp3 = np.arange(math.prod(shape3)).reshape(shape3)
6390+
np_inp1 = np.ones(math.prod(shape1)).reshape(shape1)
6391+
np_inp2 = np.ones(math.prod(shape2)).reshape(shape2)
6392+
np_inp3 = np.ones(math.prod(shape3)).reshape(shape3)
63936393

63946394
arr1 = jax.device_put(np_inp1, s)
63956395
arr2 = jax.device_put(np_inp2, s)
@@ -6436,8 +6436,8 @@ def f(condition, x, y):
64366436
def test_intermediate_einsum_conflict_error(self, mesh):
64376437
shape1 = (8, 32, 1, 16)
64386438
shape2 = (8, 32, 1, 8)
6439-
np_inp1 = np.arange(math.prod(shape1)).reshape(shape1)
6440-
np_inp2 = np.arange(math.prod(shape2)).reshape(shape2)
6439+
np_inp1 = np.ones(math.prod(shape1)).reshape(shape1)
6440+
np_inp2 = np.ones(math.prod(shape2)).reshape(shape2)
64416441

64426442
arr1 = jax.device_put(
64436443
np_inp1, NamedSharding(mesh, P(None, None, None, 'data')))

0 commit comments

Comments
 (0)