1818import collections
1919from collections .abc import Hashable , Sequence
2020import contextlib
21+ import enum
2122import functools
2223import math
2324import threading
@@ -101,6 +102,12 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
101102 return Mesh (global_mesh .devices [subcube_indices_tuple ], global_mesh .axis_names )
102103
103104
105+ class AxisTypes (enum .Enum ):
106+ Auto = enum .auto ()
107+ User = enum .auto ()
108+ Collective = enum .auto ()
109+
110+
104111_mesh_object_dict = {} # type: ignore
105112
106113
@@ -157,9 +164,11 @@ class Mesh(contextlib.ContextDecorator):
157164
158165 devices : np .ndarray
159166 axis_names : tuple [MeshAxisName , ...]
167+ axis_types : dict [AxisTypes , str | tuple [str , ...]] | None
160168
161169 def __new__ (cls , devices : np .ndarray | Sequence [xc .Device ],
162- axis_names : str | Sequence [MeshAxisName ]):
170+ axis_names : str | Sequence [MeshAxisName ],
171+ axis_types : dict [AxisTypes , str | tuple [str , ...]] | None = None ):
163172 if not isinstance (devices , np .ndarray ):
164173 devices = np .array (devices )
165174 if isinstance (axis_names , str ):
@@ -175,7 +184,10 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
175184 f"devices.ndim == { devices .ndim } and "
176185 f"len(axis_names) == { len (axis_names )} ." )
177186
178- key = (axis_names , devices .shape , tuple (devices .flat ))
187+ # TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
188+ axis_types_tuple = (None if axis_types is None else
189+ tuple (axis_types .items ()))
190+ key = (axis_names , devices .shape , tuple (devices .flat ), axis_types_tuple )
179191 val = _mesh_object_dict .get (key , None )
180192 if val is not None :
181193 return val
@@ -184,11 +196,13 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
184196 self .devices = devices .copy ()
185197 self .devices .flags .writeable = False
186198 self .axis_names = axis_names
199+ self .axis_types = axis_types
200+ self ._axis_types_tuple = axis_types_tuple
187201 _mesh_object_dict [key ] = self
188202 return self
189203
190204 def __reduce__ (self ):
191- return (type (self ), (self .devices , self .axis_names ))
205+ return (type (self ), (self .devices , self .axis_names , self . axis_types ))
192206
193207 def __eq__ (self , other ):
194208 if not isinstance (other , Mesh ):
@@ -199,12 +213,14 @@ def __eq__(self, other):
199213 return True
200214 return (self .axis_names == other .axis_names and
201215 self .devices .shape == other .devices .shape and
216+ self ._axis_types_tuple == other ._axis_types_tuple and
202217 self ._internal_device_list == other ._internal_device_list )
203218
204219 def __hash__ (self ):
205220 if not hasattr (self , '_hash' ):
206221 self ._hash = hash (
207- (self .axis_names , self ._internal_device_list , self .devices .shape ))
222+ (self .axis_names , self ._internal_device_list , self .devices .shape ,
223+ self ._axis_types_tuple ))
208224 return self ._hash
209225
210226 def __setattr__ (self , name , value ):
@@ -301,7 +317,8 @@ def __str__(self):
301317 def _repr (self ):
302318 if self .empty :
303319 return "Mesh(device_ids=[], axis_names=())"
304- return f"Mesh(device_ids={ self .device_ids !r} , axis_names={ self .axis_names !r} )"
320+ atr = '' if self .axis_types is None else f", axis_types={ self .axis_types } "
321+ return f"Mesh(device_ids={ self .device_ids !r} , axis_names={ self .axis_names !r} { atr } )"
305322
306323 def __repr__ (self ):
307324 return self ._repr
@@ -313,7 +330,7 @@ def local_devices(self):
313330
314331 @functools .cached_property
315332 def abstract_mesh (self ):
316- return AbstractMesh (self .shape_tuple )
333+ return AbstractMesh (self .shape_tuple , self . axis_types )
317334
318335
319336EMPTY_ENV = ResourceEnv (Mesh (np .empty ((), dtype = object ), ()))
@@ -338,25 +355,32 @@ class AbstractMesh:
338355 details.
339356 """
340357
341- def __init__ (self , shape_tuple : tuple [tuple [str , int ], ...]):
358+ def __init__ (self , shape_tuple : tuple [tuple [str , int ], ...],
359+ axis_types : dict [AxisTypes , str | tuple [str , ...]] | None = None ):
342360 self .shape_tuple = shape_tuple
361+ self .axis_types = axis_types
343362 if self .shape_tuple :
344363 self ._axis_names , self ._axis_sizes = list (zip (* self .shape_tuple ))
345364 else :
346365 self ._axis_names , self ._axis_sizes = (), ()
366+ # TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
367+ self ._axis_types_tuple = (None if axis_types is None else
368+ tuple (axis_types .items ()))
347369
348370 def __hash__ (self ):
349- return hash (self .shape_tuple )
371+ return hash (( self .shape_tuple , self . _axis_types_tuple ) )
350372
351373 def __eq__ (self , other ):
352374 if not isinstance (other , AbstractMesh ):
353375 return False
354376 if id (self ) == id (other ):
355377 return True
356- return self .shape_tuple == other .shape_tuple
378+ return (self .shape_tuple == other .shape_tuple and
379+ self ._axis_types_tuple == other ._axis_types_tuple )
357380
358381 def __repr__ (self ):
359- return f"AbstractMesh({ self .shape_tuple } )"
382+ atr = '' if self .axis_types is None else f", axis_types={ self .axis_types } "
383+ return f"AbstractMesh({ self .shape_tuple } { atr } )"
360384
361385 @property
362386 def axis_names (self ):
@@ -382,6 +406,12 @@ def _internal_device_list(self):
382406 def empty (self ):
383407 return self .size == 0
384408
409+ @functools .cached_property
410+ def are_all_axes_collective (self ) -> bool :
411+ if self .axis_types is None :
412+ return False
413+ return all (t == AxisTypes .Collective for t in self .axis_types .keys ())
414+
385415 @property
386416 def devices (self ):
387417 _raise_value_error ("devices" )
0 commit comments