@@ -1388,7 +1388,7 @@ def lattice_join(x, y):
13881388
13891389def valid_jaxtype (x ) -> bool :
13901390 try :
1391- concrete_aval (x )
1391+ abstractify (x )
13921392 except TypeError :
13931393 return False
13941394 else :
@@ -1400,35 +1400,9 @@ def check_valid_jaxtype(x):
14001400 f"Value { x !r} of type { type (x )} is not a valid JAX type" )
14011401
14021402
1403- # TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
1404- # This is tricky because concrete_aval includes sharding information, and
1405- # abstractify does not; further, because abstractify is in the dispatch path,
1406- # performance is important and simply adding sharding there is not an option.
1407- def concrete_aval (x ):
1408- # This differs from abstractify below in that the abstract values
1409- # include sharding where applicable. Historically (before stackless)
1410- # the returned avals were concrete, but after the stackless change
1411- # this returns ShapedArray like abstractify.
1412- # Rules are registered in pytype_aval_mappings.
1413- for typ in type (x ).__mro__ :
1414- handler = pytype_aval_mappings .get (typ )
1415- if handler : return handler (x )
1416- if hasattr (x , '__jax_array__' ):
1417- return concrete_aval (x .__jax_array__ ())
1418- raise TypeError (f"Value { x !r} with type { type (x )} is not a valid JAX "
1419- "type" )
1420-
1421-
14221403def abstractify (x ):
1423- # Historically, this was called xla.abstractify. It differs from
1424- # concrete_aval in that it excludes sharding information, and
1425- # uses a more performant path for accessing avals. Rules are
1426- # registered in xla_pytype_aval_mappings.
1427- typ = type (x )
1428- aval_fn = xla_pytype_aval_mappings .get (typ )
1429- if aval_fn : return aval_fn (x )
1430- for typ in typ .__mro__ :
1431- aval_fn = xla_pytype_aval_mappings .get (typ )
1404+ for typ in type (x ).__mro__ :
1405+ aval_fn = pytype_aval_mappings .get (typ )
14321406 if aval_fn : return aval_fn (x )
14331407 if hasattr (x , '__jax_array__' ):
14341408 return abstractify (x .__jax_array__ ())
@@ -1439,7 +1413,7 @@ def get_aval(x):
14391413 if isinstance (x , Tracer ):
14401414 return x .aval
14411415 else :
1442- return concrete_aval (x )
1416+ return abstractify (x )
14431417
14441418get_type = get_aval
14451419
@@ -1835,7 +1809,6 @@ def to_tangent_aval(self):
18351809 self .weak_type )
18361810
18371811pytype_aval_mappings : dict [type , Callable [[Any ], AbstractValue ]] = {}
1838- xla_pytype_aval_mappings : dict [type , Callable [[Any ], AbstractValue ]] = {}
18391812
18401813
18411814class DArray :
@@ -1892,7 +1865,6 @@ def data(self):
18921865
18931866pytype_aval_mappings [DArray ] = \
18941867 lambda x : DShapedArray (x ._aval .shape , x ._aval .dtype , x ._aval .weak_type )
1895- xla_pytype_aval_mappings [DArray ] = lambda x : x ._aval
18961868
18971869@dataclass (frozen = True )
18981870class bint (dtypes .ExtendedDType ):
@@ -1925,7 +1897,6 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
19251897 def __setitem__ (self , idx , x ): return get_aval (self )._setitem (self , idx , x )
19261898 def __repr__ (self ) -> str : return 'Mutable' + repr (self [...])
19271899pytype_aval_mappings [MutableArray ] = lambda x : x ._aval
1928- xla_pytype_aval_mappings [MutableArray ] = lambda x : x ._aval
19291900
19301901def mutable_array (init_val ):
19311902 return mutable_array_p .bind (init_val )
@@ -1979,7 +1950,6 @@ def __init__(self, buf):
19791950 def block_until_ready (self ):
19801951 self ._buf .block_until_ready ()
19811952pytype_aval_mappings [Token ] = lambda _ : abstract_token
1982- xla_pytype_aval_mappings [Token ] = lambda _ : abstract_token
19831953
19841954
19851955# TODO(dougalm): Deprecate these. They're just here for backwards compat.
0 commit comments