@@ -345,22 +345,25 @@ def _vprim_impl(*inputs, prim, __axis_size, __axis_name, __batch_axes, params):
345345 return impl (* inputs )
346346
347347
348+ if hasattr (jax .extend .core , "mapped_aval" ):
349+ _mapped_aval = jax .extend .core .mapped_aval # pyright: ignore[reportAttributeAccessIssue]
350+ else :
351+ _mapped_aval = jax .core .mapped_aval
352+ if hasattr (jax .extend .core , "unmapped_aval" ):
353+ _unmapped_aval = jax .extend .core .unmapped_aval # pyright: ignore[reportAttributeAccessIssue,reportAssignmentType]
354+ else :
355+ _unmapped_aval = jax .core .unmapped_aval # pyright: ignore[reportAssignmentType]
348356if jax .__version_info__ >= (0 , 5 , 1 ):
357+ _old_unmapped_aval = _unmapped_aval
349358
350359 def _unmapped_aval (axis_size , axis_name , axis , aval ):
351360 del axis_name
352- return jax .core .unmapped_aval (axis_size , axis , aval ) # pyright: ignore[reportCallIssue]
353-
354- else :
355- # signature (axis_size, axis_name, axis, aval)
356- _unmapped_aval = jax .core .unmapped_aval # pyright: ignore[reportAssignmentType]
361+ return _old_unmapped_aval (axis_size , axis , aval ) # pyright: ignore[reportCallIssue]
357362
358363
359364def _vprim_abstract_eval (* inputs , prim , __axis_size , __axis_name , __batch_axes , params ):
360365 assert len (inputs ) == len (__batch_axes )
361- inputs = [
362- jax .core .mapped_aval (__axis_size , b , x ) for x , b in zip (inputs , __batch_axes )
363- ]
366+ inputs = [_mapped_aval (__axis_size , b , x ) for x , b in zip (inputs , __batch_axes )]
364367 abstract_eval = _vprim_abstract_eval_registry [prim ]
365368 outs = abstract_eval (* inputs , ** dict (params ))
366369 outs = [_unmapped_aval (__axis_size , __axis_name , 0 , x ) for x in outs ]
0 commit comments