@@ -1893,14 +1893,17 @@ def get_sharding(sharding, shape):
18931893
18941894
18951895class ShapedArray (UnshapedArray ):
1896- __slots__ = ['shape' , 'sharding' ] # inherits slots from parent
1896+ __slots__ = ['shape' , 'sharding' , 'varying_manual_axes' ] # inherits slots from parent
18971897 array_abstraction_level = 2
18981898
1899- def __init__ (self , shape , dtype , weak_type = False , * , sharding = None ):
1899+ def __init__ (self , shape , dtype , weak_type = False , * , sharding = None ,
1900+ varying_manual_axes : frozenset [AxisName ] = frozenset ()):
19001901 self .shape = canonicalize_shape (shape )
19011902 self .dtype = _dtype_object (dtype )
19021903 self .weak_type = weak_type
19031904 self .sharding = get_sharding (sharding , self .shape )
1905+ if config .varying_axes_in_types .value :
1906+ self .varying_manual_axes = varying_manual_axes
19041907
19051908 def update (self , shape = None , dtype = None , weak_type = None , ** kwargs ):
19061909 if shape is None :
@@ -1911,6 +1914,9 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
19111914 weak_type = self .weak_type
19121915 if 'sharding' not in kwargs :
19131916 kwargs ['sharding' ] = self .sharding
1917+ if 'varying_manual_axes' not in kwargs :
1918+ kwargs ['varying_manual_axes' ] = getattr (self , 'varying_manual_axes' ,
1919+ frozenset ())
19141920 return ShapedArray (shape , dtype , weak_type , ** kwargs )
19151921
19161922 ndim = property (lambda self : len (self .shape ))
@@ -1927,17 +1933,22 @@ def __eq__(self, other):
19271933 return (type (self ) is type (other )
19281934 and self .dtype == other .dtype and self .shape == other .shape
19291935 and self .weak_type == other .weak_type
1930- and self .sharding == other .sharding )
1936+ and self .sharding == other .sharding
1937+ and (getattr (self , 'varying_manual_axes' , frozenset ()) ==
1938+ getattr (other , 'varying_manual_axes' , frozenset ())))
19311939
19321940 def __hash__ (self ):
19331941 # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
19341942 # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
19351943 # the unique character code via hash(self.dtype.char)
1936- return hash ((self .shape , self .dtype , self .weak_type , self .sharding ))
1944+ return hash ((self .shape , self .dtype , self .weak_type , self .sharding ,
1945+ getattr (self , 'varying_manual_axes' , frozenset ())))
19371946
19381947 def to_tangent_aval (self ):
1939- return ShapedArray (self .shape , primal_dtype_to_tangent_dtype (self .dtype ),
1940- self .weak_type , sharding = self .sharding )
1948+ return ShapedArray (
1949+ self .shape , primal_dtype_to_tangent_dtype (self .dtype ),
1950+ self .weak_type , sharding = self .sharding ,
1951+ varying_manual_axes = getattr (self , 'varying_manual_axes' , frozenset ()))
19411952
19421953 def str_short (self , short_dtypes = False , mesh_axis_types = False ):
19431954 dt_str = (dtypes .short_dtype_name (self .dtype ) if short_dtypes else
0 commit comments