Commit cbcc883

bartchr808 authored and Google-ML-Automation committed
#sdy add repr for Sdy ArraySharding and DimSharding
PiperOrigin-RevId: 713422071
1 parent 196eec8 commit cbcc883

2 files changed: +40 -0 lines changed

jax/_src/sharding_impls.py

Lines changed: 18 additions & 0 deletions
@@ -125,6 +125,17 @@ def build(self) -> sdy.DimensionShardingAttr:
         is_closed=self.is_closed,
         priority=self.priority)
 
+  def __repr__(self):
+    return f'SdyDimSharding({self._custom_repr()})'
+
+  def _custom_repr(self):
+    axes_repr = ', '.join(f"'{a}'" for a in self.axes)
+    open_repr = ''
+    if not self.is_closed:
+      open_repr = ', ?' if self.axes else '?'
+    priority_repr = '' if self.priority is None else f'p{self.priority}'
+    return f'{{{axes_repr}{open_repr}}}{priority_repr}'
+
 
 @dataclasses.dataclass
 class SdyArraySharding:
@@ -146,6 +157,13 @@ def build(self) -> sdy.TensorShardingAttr:
         mesh_attr,
         [dim_sharding.build() for dim_sharding in self.dimension_shardings])
 
+  def __repr__(self):
+    dim_sharding_repr = ', '.join(
+        d._custom_repr() for d in self.dimension_shardings)
+    device_id_repr = (f', device_ids={self.logical_device_ids}'
+                      if self.logical_device_ids is not None else '')
+    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr})"
+
 
 @util.cache(max_size=4096, trace_context_in_key=False)
 def named_sharding_to_xla_hlo_sharding(
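For illustration, here is a minimal usage sketch of the new reprs. It is not part of the commit: it imports the internal module jax._src.sharding_impls and relies on the constructor keywords shown in the diff above, which are internal and may change.

# Illustrative only: uses JAX-internal classes from the diff above.
from jax._src import sharding_impls

# Open dimension sharding over the 'data' axis with priority 1.
dim = sharding_impls.SdyDimSharding(axes=['data'], is_closed=False, priority=1)
print(repr(dim))  # SdyDimSharding({'data', ?}p1)

# Array sharding: first dim fixed to 'data', second dim left open.
arr = sharding_impls.SdyArraySharding(
    mesh_shape=(('data', 2), ('model', 4)),
    dimension_shardings=[
        sharding_impls.SdyDimSharding(axes=['data'], is_closed=True),
        sharding_impls.SdyDimSharding(axes=[], is_closed=False)])
print(repr(arr))  # SdyArraySharding([{'data'}, {?}])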

tests/pjit_test.py

Lines changed: 22 additions & 0 deletions
@@ -6632,6 +6632,28 @@ def f(x, y):
     lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text()
     self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str)
 
+  def test_array_sharding_repr_with_priority(self):
+    sharding = sharding_impls.SdyArraySharding(
+        mesh_shape=(('data', 4), ('model', 8), ('expert', 2)),
+        dimension_shardings=[
+            sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_closed=True),
+            sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)])
+    self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])")
+
+  def test_array_sharding_repr_with_logical_ids(self):
+    abstract_mesh = jax.sharding.AbstractMesh((('x', 4), ('y', 8), ('z', 2)))
+    ns = NamedSharding(abstract_mesh, P(('x', 'y'), 'z', P.UNCONSTRAINED, None),
+                       _logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3])
+    self.assertEqual(repr(ns._to_sdy_sharding(4)),
+                     "SdyArraySharding([{'x', 'y'}, {'z'}, {?}, {}], "
+                     "device_ids=[4, 5, 6, 7, 0, 1, 2, 3])")
+
+  def test_dimension_sharding_repr(self):
+    dim_sharding = sharding_impls.SdyDimSharding(
+        axes=['data', 'model'], is_closed=False, priority=2)
+    self.assertEqual(repr(dim_sharding),
+                     "SdyDimSharding({'data', 'model', ?}p2)")
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
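As a complement to the tests above, a short sketch of the remaining SdyDimSharding repr cases (again a hypothetical example importing the internal jax._src.sharding_impls module; the expected strings follow from the _custom_repr logic in the first diff):

# Illustrative only; not part of the commit.
from jax._src import sharding_impls as si

# Closed sharding over two axes, no priority.
print(repr(si.SdyDimSharding(axes=['x', 'y'], is_closed=True)))  # SdyDimSharding({'x', 'y'})

# Fully open (unconstrained) dimension: no axes, not closed.
print(repr(si.SdyDimSharding(axes=[], is_closed=False)))  # SdyDimSharding({?})

# Replicated dimension: no axes, closed.
print(repr(si.SdyDimSharding(axes=[], is_closed=True)))  # SdyDimSharding({})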
