@@ -1400,7 +1400,16 @@ def check_valid_jaxtype(x):
14001400 f"Value { x !r} of type { type (x )} is not a valid JAX type" )
14011401
14021402
1403+ # TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
1404+ # This is tricky because concrete_aval includes sharding information, and
1405+ # abstractify does not; further, because abstractify is in the dispatch path,
1406+ # performance is important and simply adding sharding there is not an option.
14031407def concrete_aval (x ):
1408+ # This differs from abstractify below in that the abstract values
1409+ # include sharding where applicable. Historically (before stackless)
1410+ # the returned avals were concrete, but after the stackless change
1411+ # this returns ShapedArray like abstractify.
1412+ # Rules are registered in pytype_aval_mappings.
14041413 for typ in type (x ).__mro__ :
14051414 handler = pytype_aval_mappings .get (typ )
14061415 if handler : return handler (x )
@@ -1410,6 +1419,22 @@ def concrete_aval(x):
14101419 "type" )
14111420
14121421
1422+ def abstractify (x ):
1423+ # Historically, this was called xla.abstractify. It differs from
1424+ # concrete_aval in that it excludes sharding information, and
1425+ # uses a more performant path for accessing avals. Rules are
1426+ # registered in xla_pytype_aval_mappings.
1427+ typ = type (x )
1428+ aval_fn = xla_pytype_aval_mappings .get (typ )
1429+ if aval_fn : return aval_fn (x )
1430+ for typ in typ .__mro__ :
1431+ aval_fn = xla_pytype_aval_mappings .get (typ )
1432+ if aval_fn : return aval_fn (x )
1433+ if hasattr (x , '__jax_array__' ):
1434+ return abstractify (x .__jax_array__ ())
1435+ raise TypeError (f"Argument '{ x } ' of type '{ type (x )} ' is not a valid JAX type" )
1436+
1437+
14131438def get_aval (x ):
14141439 if isinstance (x , Tracer ):
14151440 return x .aval
@@ -1810,6 +1835,7 @@ def to_tangent_aval(self):
18101835 self .weak_type )
18111836
18121837pytype_aval_mappings : dict [type , Callable [[Any ], AbstractValue ]] = {}
1838+ xla_pytype_aval_mappings : dict [type , Callable [[Any ], AbstractValue ]] = {}
18131839
18141840
18151841class DArray :
@@ -1866,6 +1892,7 @@ def data(self):
18661892
18671893pytype_aval_mappings [DArray ] = \
18681894 lambda x : DShapedArray (x ._aval .shape , x ._aval .dtype , x ._aval .weak_type )
1895+ xla_pytype_aval_mappings [DArray ] = lambda x : x ._aval
18691896
18701897@dataclass (frozen = True )
18711898class bint (dtypes .ExtendedDType ):
@@ -1898,6 +1925,7 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
18981925 def __setitem__ (self , idx , x ): return get_aval (self )._setitem (self , idx , x )
18991926 def __repr__ (self ) -> str : return 'Mutable' + repr (self [...])
19001927pytype_aval_mappings [MutableArray ] = lambda x : x ._aval
1928+ xla_pytype_aval_mappings [MutableArray ] = lambda x : x ._aval
19011929
19021930def mutable_array (init_val ):
19031931 return mutable_array_p .bind (init_val )
@@ -1951,6 +1979,7 @@ def __init__(self, buf):
19511979 def block_until_ready (self ):
19521980 self ._buf .block_until_ready ()
19531981pytype_aval_mappings [Token ] = lambda _ : abstract_token
1982+ xla_pytype_aval_mappings [Token ] = lambda _ : abstract_token
19541983
19551984
19561985# TODO(dougalm): Deprecate these. They're just here for backwards compat.
0 commit comments