@@ -112,20 +112,17 @@ class NamedSharding(JSharding.Sharding):
112112 mesh : mesh_lib .Mesh | mesh_lib .AbstractMesh
113113 spec : PartitionSpec
114114 _memory_kind : str | None
115- _manual_axes : frozenset [MeshAxisName ]
116115 _logical_device_ids : tuple [int , ...] | None
117116
118117 @use_cpp_method ()
119118 def __init__ (
120119 self , mesh : mesh_lib .Mesh | mesh_lib .AbstractMesh , spec : PartitionSpec , * ,
121- memory_kind : str | None = None , _manual_axes = frozenset (),
122- _logical_device_ids = None ):
120+ memory_kind : str | None = None , _logical_device_ids = None ):
123121 self .mesh = mesh
124122 self .spec = spec
125123 self ._memory_kind = memory_kind
126- self ._manual_axes = _manual_axes
127124 self ._logical_device_ids = _logical_device_ids
128- check_pspec (self .mesh , self .spec , self . _manual_axes )
125+ check_pspec (self .mesh , self .spec )
129126
130127 def __repr__ (self ):
131128 mem = '' if self .memory_kind is None else f', memory_kind={ self .memory_kind } '
@@ -137,7 +134,6 @@ def __repr__(self):
137134 def __reduce__ (self ):
138135 return (type (self ), (self .mesh , self .spec ),
139136 {'memory_kind' : self .memory_kind ,
140- '_manual_axes' : self ._manual_axes ,
141137 '_logical_device_ids' : self ._logical_device_ids })
142138
143139 @property
@@ -147,8 +143,7 @@ def memory_kind(self) -> str | None:
147143 def __hash__ (self ):
148144 if not hasattr (self , '_hash' ):
149145 self ._hash = hash (
150- (self .mesh , self .memory_kind , self .spec , self ._manual_axes ,
151- self ._logical_device_ids ))
146+ (self .mesh , self .memory_kind , self .spec , self ._logical_device_ids ))
152147 return self ._hash
153148
154149 def __eq__ (self , other ):
@@ -158,7 +153,6 @@ def __eq__(self, other):
158153 return True
159154 if (self .spec != other .spec
160155 or self .memory_kind != other .memory_kind
161- or self ._manual_axes != other ._manual_axes
162156 or self ._logical_device_ids != other ._logical_device_ids ):
163157 return False
164158 return self .mesh is other .mesh or self .mesh == other .mesh
@@ -333,9 +327,7 @@ def named_sharding_to_xla_hlo_sharding(
333327 mesh_axis_pos = {name : i for i , name in enumerate (self .mesh .axis_names )}
334328
335329 special_axes = {}
336- mesh_manual_axes = {n for n , t in self .mesh ._name_to_type .items ()
337- if t == mesh_lib .AxisType .Manual }
338- manual_axes = self ._manual_axes .union (mesh_manual_axes )
330+ manual_axes = frozenset (self .mesh .manual_axes )
339331 if manual_axes :
340332 axis_names = self .mesh .axis_names
341333 for manual_axis in manual_axes :
@@ -420,7 +412,7 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
420412@cache (max_size = 128 , trace_context_in_key = False )
421413def check_pspec (mesh , spec , _manual_axes = frozenset ()):
422414 _check_unique_resources (spec , "NamedSharding spec" , mesh )
423- _check_mesh_resource_axis (mesh , spec , _manual_axes )
415+ _check_mesh_resource_axis (mesh , spec )
424416
425417class DuplicateSpecError (Exception ):
426418 def __init__ (self , message , mesh , pspec ):
@@ -455,7 +447,7 @@ def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None
455447 mesh = mesh , pspec = pspec )
456448
457449
458- def _check_mesh_resource_axis (mesh , pspec , _manual_axes ):
450+ def _check_mesh_resource_axis (mesh , pspec ):
459451 for p in pspec :
460452 if p is PartitionSpec .UNCONSTRAINED or p is None :
461453 continue
@@ -465,10 +457,6 @@ def _check_mesh_resource_axis(mesh, pspec, _manual_axes):
465457 raise ValueError (
466458 f"Resource axis: { r } of { pspec } "
467459 f"is not found in mesh: { tuple (mesh .shape .keys ())} ." )
468- if r in _manual_axes :
469- raise ValueError (
470- f"Axis: { r } of { pspec } "
471- f"is also found in manual_axes: { _manual_axes } ." ) from None
472460 if not all (mesh ._name_to_type [p [0 ]] == mesh ._name_to_type [r ] for r in p ):
473461 raise ValueError (
474462 'AxisTypes should be the same in a tuple subset of PartitionSpec:'
0 commit comments