@@ -178,7 +178,9 @@ def copy_to_host_async(self):
178178 def aval (self ):
179179 logical_sharding = (self .sharding if hasattr (self ._base_array , 'sharding' )
180180 else None )
181- return keys_shaped_array (self ._impl , self .shape , logical_sharding )
181+ vma = (self ._base_array .aval .vma if config .varying_axes_in_types .value else frozenset ()
182+ if hasattr (self ._base_array , 'aval' ) else frozenset ())
183+ return keys_shaped_array (self ._impl , self .shape , logical_sharding , vma )
182184
183185 @property
184186 def shape (self ):
@@ -329,8 +331,8 @@ def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray
329331 return random_seed (seed , impl = impl )
330332
331333
332- def keys_shaped_array (impl , shape , sharding ):
333- aval = core .ShapedArray (shape , KeyTy (impl ))
334+ def keys_shaped_array (impl , shape , sharding , vma ):
335+ aval = core .ShapedArray (shape , KeyTy (impl ), vma = vma )
334336 return core .update_aval_with_sharding (aval , sharding )
335337
336338def base_arr_shape_to_keys_shape (impl , base_arr_shape ):
@@ -550,7 +552,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray:
550552
551553@random_seed_p .def_abstract_eval
552554def random_seed_abstract_eval (seeds_aval , * , impl ):
553- return keys_shaped_array (impl , seeds_aval .shape , seeds_aval .sharding )
555+ out_vma = seeds_aval .vma if config .varying_axes_in_types .value else frozenset ()
556+ return keys_shaped_array (impl , seeds_aval .shape , seeds_aval .sharding , out_vma )
554557
555558@random_seed_p .def_impl
556559def random_seed_impl (seeds , * , impl ):
@@ -584,8 +587,9 @@ def random_split_abstract_eval(keys_aval, *, shape):
584587 # TODO(yashkatariya): random_split should take sharding as an arg too so we
585588 # don't choose None here?
586589 new_spec = (* keys_aval .sharding .spec , * [None ] * len (shape ))
590+ out_vma = keys_aval .vma if config .varying_axes_in_types .value else frozenset ()
587591 return keys_shaped_array (keys_aval .dtype ._impl , (* keys_aval .shape , * shape ),
588- keys_aval .sharding .with_spec (new_spec ))
592+ keys_aval .sharding .with_spec (new_spec ), out_vma )
589593
590594@random_split_p .def_impl
591595def random_split_impl (keys , * , shape ):
@@ -611,7 +615,9 @@ def random_split_lowering(ctx, keys, *, shape):
611615
612616
613617def random_fold_in (keys , msgs ):
614- return random_fold_in_p .bind (keys , jnp .asarray (msgs ))
618+ msgs = jnp .asarray (msgs )
619+ keys , msgs = core .standard_insert_pbroadcast (keys , msgs )
620+ return random_fold_in_p .bind (keys , msgs )
615621
616622random_fold_in_p = core .Primitive ('random_fold_in' )
617623ad .defjvp_zero (random_fold_in_p )
@@ -623,7 +629,9 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
623629 'random_fold_in' , keys_aval , msgs_aval )
624630 sharding = lax_internal .broadcasting_sharding_rule (
625631 'random_fold_in' , keys_aval , msgs_aval )
626- return core .ShapedArray (shape , keys_aval .dtype , sharding = sharding )
632+ vma = (core .standard_vma_rule ('random_fold_in' , keys_aval , msgs_aval )
633+ if config .varying_axes_in_types .value else frozenset ())
634+ return core .ShapedArray (shape , keys_aval .dtype , sharding = sharding , vma = vma )
627635
628636@random_fold_in_p .def_impl
629637def random_fold_in_impl (keys , msgs ):
@@ -661,7 +669,8 @@ def random_bits(keys, bit_width, shape):
661669def random_bits_abstract_eval (keys_aval , * , bit_width , shape ):
662670 out_shape = (* keys_aval .shape , * shape )
663671 out_dtype = dtypes .dtype (f'uint{ bit_width } ' )
664- return core .ShapedArray (out_shape , out_dtype )
672+ vma = keys_aval .vma if config .varying_axes_in_types .value else frozenset ()
673+ return core .ShapedArray (out_shape , out_dtype , vma = vma )
665674
666675@random_bits_p .def_impl
667676def random_bits_impl (keys , * , bit_width , shape ):
@@ -718,7 +727,9 @@ def random_wrap(base_arr, *, impl):
718727def random_wrap_abstract_eval (base_arr_aval , * , impl ):
719728 shape = base_arr_shape_to_keys_shape (impl , base_arr_aval .shape )
720729 sharding = logical_sharding (shape , KeyTy (impl ), base_arr_aval .sharding )
721- return keys_shaped_array (impl , shape , sharding )
730+ out_vma = (base_arr_aval .vma if config .varying_axes_in_types .value else
731+ frozenset ())
732+ return keys_shaped_array (impl , shape , sharding , out_vma )
722733
723734@random_wrap_p .def_impl
724735def random_wrap_impl (base_arr , * , impl ):
0 commit comments