Skip to content

Commit 833c7ba

Browse files
Google-ML-Automationjax authors
authored andcommitted
Allow generation of sharding strategies with mixed mesh shapes by default.
PiperOrigin-RevId: 641930205
1 parent 0739d52 commit 833c7ba

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/pjit_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1564,7 +1564,7 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape,
15641564

15651565
def test_xla_arr_sharding_mismatch(self):
15661566
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
1567-
global_input_shape = (4, 2)
1567+
global_input_shape = (6, 2)
15681568
input_data = np.arange(
15691569
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
15701570

0 commit comments

Comments
 (0)