@@ -1388,7 +1388,7 @@ def lattice_join(x, y):
13881388
13891389def valid_jaxtype (x ) -> bool :
13901390 try :
1391- abstractify (x )
1391+ concrete_aval (x )
13921392 except TypeError :
13931393 return False
13941394 else :
@@ -1400,9 +1400,35 @@ def check_valid_jaxtype(x):
14001400 f"Value { x !r} of type { type (x )} is not a valid JAX type" )
14011401
14021402
1403- def abstractify (x ):
1403+ # TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
1404+ # This is tricky because concrete_aval includes sharding information, and
1405+ # abstractify does not; further, because abstractify is in the dispatch path,
1406+ # performance is important and simply adding sharding there is not an option.
1407+ def concrete_aval (x ):
1408+ # This differs from abstractify below in that the abstract values
1409+ # include sharding where applicable. Historically (before stackless)
1410+ # the returned avals were concrete, but after the stackless change
1411+ # this returns ShapedArray like abstractify.
1412+ # Rules are registered in pytype_aval_mappings.
14041413 for typ in type (x ).__mro__ :
1405- aval_fn = pytype_aval_mappings .get (typ )
1414+ handler = pytype_aval_mappings .get (typ )
1415+ if handler : return handler (x )
1416+ if hasattr (x , '__jax_array__' ):
1417+ return concrete_aval (x .__jax_array__ ())
1418+ raise TypeError (f"Value { x !r} with type { type (x )} is not a valid JAX "
1419+ "type" )
1420+
1421+
1422+ def abstractify (x ):
1423+ # Historically, this was called xla.abstractify. It differs from
1424+ # concrete_aval in that it excludes sharding information, and
1425+ # uses a more performant path for accessing avals. Rules are
1426+ # registered in xla_pytype_aval_mappings.
1427+ typ = type (x )
1428+ aval_fn = xla_pytype_aval_mappings .get (typ )
1429+ if aval_fn : return aval_fn (x )
1430+ for typ in typ .__mro__ :
1431+ aval_fn = xla_pytype_aval_mappings .get (typ )
14061432 if aval_fn : return aval_fn (x )
14071433 if hasattr (x , '__jax_array__' ):
14081434 return abstractify (x .__jax_array__ ())
@@ -1413,7 +1439,7 @@ def get_aval(x):
14131439 if isinstance (x , Tracer ):
14141440 return x .aval
14151441 else :
1416- return abstractify (x )
1442+ return concrete_aval (x )
14171443
14181444get_type = get_aval
14191445
@@ -1809,6 +1835,7 @@ def to_tangent_aval(self):
18091835 self .weak_type )
18101836
18111837pytype_aval_mappings : dict [type , Callable [[Any ], AbstractValue ]] = {}
1838+ xla_pytype_aval_mappings : dict [type , Callable [[Any ], AbstractValue ]] = {}
18121839
18131840
18141841class DArray :
@@ -1865,6 +1892,7 @@ def data(self):
18651892
18661893pytype_aval_mappings [DArray ] = \
18671894 lambda x : DShapedArray (x ._aval .shape , x ._aval .dtype , x ._aval .weak_type )
1895+ xla_pytype_aval_mappings [DArray ] = lambda x : x ._aval
18681896
18691897@dataclass (frozen = True )
18701898class bint (dtypes .ExtendedDType ):
@@ -1897,6 +1925,7 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
18971925 def __setitem__ (self , idx , x ): return get_aval (self )._setitem (self , idx , x )
18981926 def __repr__ (self ) -> str : return 'Mutable' + repr (self [...])
18991927pytype_aval_mappings [MutableArray ] = lambda x : x ._aval
1928+ xla_pytype_aval_mappings [MutableArray ] = lambda x : x ._aval
19001929
19011930def mutable_array (init_val ):
19021931 return mutable_array_p .bind (init_val )
@@ -1952,6 +1981,7 @@ def __init__(self, buf):
19521981 def block_until_ready (self ):
19531982 self ._buf .block_until_ready ()
19541983pytype_aval_mappings [Token ] = lambda _ : abstract_token
1984+ xla_pytype_aval_mappings [Token ] = lambda _ : abstract_token
19551985
19561986
19571987# TODO(dougalm): Deprecate these. They're just here for backwards compat.
0 commit comments