@@ -6632,6 +6632,28 @@ def f(x, y):
66326632 lowered_str = jax .jit (f , in_shardings = [AUTO (mesh ), AUTO (mesh )]).lower (x , x ).as_text ()
66336633 self .assertIn ('sdy.mesh @mesh = <["x"=8]>' , lowered_str )
66346634
6635+ def test_array_sharding_repr_with_priority (self ):
6636+ sharding = sharding_impls .SdyArraySharding (
6637+ mesh_shape = (('data' , 4 ), ('model' , 8 ), ('expert' , 2 )),
6638+ dimension_shardings = [
6639+ sharding_impls .SdyDimSharding (axes = ['data' , 'expert' ], is_closed = True ),
6640+ sharding_impls .SdyDimSharding (axes = ['model' ], is_closed = False , priority = 2 )])
6641+ self .assertEqual (repr (sharding ), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])" )
6642+
6643+ def test_array_sharding_repr_with_logical_ids (self ):
6644+ abstract_mesh = jax .sharding .AbstractMesh ((('x' , 4 ), ('y' , 8 ), ('z' , 2 )))
6645+ ns = NamedSharding (abstract_mesh , P (('x' , 'y' ), 'z' , P .UNCONSTRAINED , None ),
6646+ _logical_device_ids = [4 , 5 , 6 , 7 , 0 , 1 , 2 , 3 ])
6647+ self .assertEqual (repr (ns ._to_sdy_sharding (4 )),
6648+ "SdyArraySharding([{'x', 'y'}, {'z'}, {?}, {}], "
6649+ "device_ids=[4, 5, 6, 7, 0, 1, 2, 3])" )
6650+
6651+ def test_dimension_sharding_repr (self ):
6652+ dim_sharding = sharding_impls .SdyDimSharding (
6653+ axes = ['data' , 'model' ], is_closed = False , priority = 2 )
6654+ self .assertEqual (repr (dim_sharding ),
6655+ "SdyDimSharding({'data', 'model', ?}p2)" )
6656+
66356657
66366658if __name__ == '__main__' :
66376659 absltest .main (testLoader = jtu .JaxTestLoader ())
0 commit comments