@@ -111,8 +111,6 @@ def __repr__(self):
111111 return self .name
112112
113113def axis_names_to_types (axis_types ) -> dict [str , AxisTypes ]:
114- if axis_types is None :
115- return {}
116114 d = {}
117115 for t , names in axis_types .items ():
118116 if isinstance (names , tuple ):
@@ -179,7 +177,7 @@ class Mesh(contextlib.ContextDecorator):
179177
180178 devices : np .ndarray
181179 axis_names : tuple [MeshAxisName , ...]
182- axis_types : MeshAxisType | None
180+ axis_types : MeshAxisType
183181
184182 def __new__ (cls , devices : np .ndarray | Sequence [xc .Device ],
185183 axis_names : str | Sequence [MeshAxisName ], * ,
@@ -199,9 +197,9 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
199197 f"devices.ndim == { devices .ndim } and "
200198 f"len(axis_names) == { len (axis_names )} ." )
201199
202- # TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
203- axis_types_tuple = ( None if axis_types is None else
204- tuple (axis_types .items () ))
200+ axis_types = ({ AxisTypes . Auto : axis_names } if axis_types is None else
201+ axis_types )
202+ axis_types_tuple = tuple (axis_types .items ())
205203 key = (axis_names , devices .shape , tuple (devices .flat ), axis_types_tuple )
206204 val = _mesh_object_dict .get (key , None )
207205 if val is not None :
@@ -337,7 +335,7 @@ def __str__(self):
337335 def _repr (self ):
338336 if self .empty :
339337 return "Mesh(device_ids=[], axis_names=())"
340- atr = '' if self . axis_types is None else f", axis_types={ self .axis_types } "
338+ atr = f", axis_types={ self .axis_types } "
341339 return f"Mesh(device_ids={ self .device_ids !r} , axis_names={ self .axis_names !r} { atr } )"
342340
343341 def __repr__ (self ):
@@ -378,14 +376,13 @@ class AbstractMesh:
378376 def __init__ (self , shape_tuple : tuple [tuple [str , int ], ...], * ,
379377 axis_types : MeshAxisType | None = None ):
380378 self .shape_tuple = shape_tuple
381- self .axis_types = axis_types
382379 if self .shape_tuple :
383380 self ._axis_names , self ._axis_sizes = list (zip (* self .shape_tuple ))
384381 else :
385382 self ._axis_names , self ._axis_sizes = (), ()
386- # TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
387- self . _axis_types_tuple = ( None if axis_types is None else
388- tuple (axis_types .items () ))
383+ self . axis_types = ({ AxisTypes . Auto : self . _axis_names } if axis_types is None
384+ else axis_types )
385+ self . _axis_types_tuple = tuple (self . axis_types .items ())
389386
390387 def __hash__ (self ):
391388 return hash ((self .shape_tuple , self ._axis_types_tuple ))
@@ -399,7 +396,7 @@ def __eq__(self, other):
399396 self ._axis_types_tuple == other ._axis_types_tuple )
400397
401398 def __repr__ (self ):
402- atr = '' if self . axis_types is None else f", axis_types={ self .axis_types } "
399+ atr = f", axis_types={ self .axis_types } "
403400 return f"AbstractMesh({ self .shape_tuple } { atr } )"
404401
405402 @property
@@ -432,26 +429,18 @@ def empty(self):
432429
433430 @functools .cached_property
434431 def _are_all_axes_collective (self ) -> bool :
435- if self .axis_types is None :
436- return False
437432 return all (t == AxisTypes .Collective for t in self .axis_types .keys ())
438433
439434 @functools .cached_property
440435 def _are_all_axes_auto (self ) -> bool :
441- if self .axis_types is None :
442- return False
443436 return all (t == AxisTypes .Auto for t in self .axis_types .keys ())
444437
445438 @functools .cached_property
446439 def _any_axis_collective (self ) -> bool :
447- if self .axis_types is None :
448- return False
449440 return any (t == AxisTypes .Collective for t in self .axis_types .keys ())
450441
451442 @functools .cached_property
452443 def _any_axis_auto (self ) -> bool :
453- if self .axis_types is None :
454- return False
455444 return any (t == AxisTypes .Auto for t in self .axis_types .keys ())
456445
457446 @property
@@ -494,8 +483,6 @@ def _raise_value_error(name):
494483
495484@contextlib .contextmanager
496485def set_abstract_mesh (mesh : AbstractMesh ):
497- if mesh is not None and mesh .axis_types is None :
498- raise RuntimeError ('Please set the AxisTypes of Mesh.' )
499486 prev_val = jax_config .abstract_mesh_context_manager .swap_local (mesh )
500487 try :
501488 yield
0 commit comments