
Commit 563c3e2

yashk2810 authored and Google-ML-Automation committed

Add standard pbroadcast rules to more primitives. This should cover all primitives for which shard_map registered standard_rewrite rules.
PiperOrigin-RevId: 741516445
1 parent f1ebb1e commit 563c3e2

5 files changed: 41 additions (+41) and 21 deletions (-21)
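For context: when config.varying_axes_in_types is enabled under shard_map, every aval carries a vma set recording the manual mesh axes over which the value varies. The standard rewrite has two halves: standard_insert_pbroadcast aligns all arguments to the union of their vma sets before a primitive binds, and standard_vma_rule checks that the (now aligned) inputs agree and gives the outputs that same set. A minimal standalone sketch of the pattern follows; ToyAval, toy_pbroadcast, toy_insert_pbroadcast, and toy_vma_rule are hypothetical stand-ins, not JAX's real classes or functions:

    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class ToyAval:
        shape: tuple
        vma: frozenset = field(default_factory=frozenset)  # varying manual axes

    def toy_pbroadcast(aval, axes):
        # Mark a value as varying over additional manual axes.
        return ToyAval(aval.shape, aval.vma | axes)

    def toy_insert_pbroadcast(*avals):
        # Align every input to the union of all inputs' varying axes,
        # mirroring the `if out_vma - src else arg` logic in core.py below.
        out_vma = frozenset().union(*[a.vma for a in avals])
        return [toy_pbroadcast(a, out_vma - a.vma) if out_vma - a.vma else a
                for a in avals]

    def toy_vma_rule(*avals):
        # After insertion all inputs agree; the output inherits that set.
        vma, *vmas = [a.vma for a in avals]
        assert all(v == vma for v in vmas)
        return vma

This commit threads that pattern through ffi_call, custom_linear_solve, reduce_window, and the PRNG primitives below.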

jax/_src/core.py

Lines changed: 6 additions & 2 deletions
@@ -1528,7 +1528,7 @@ def check_valid_jaxtype(x):
 
 def update_aval_with_sharding(aval, sharding):
   if isinstance(sharding, NamedSharding):
-    aval = aval.update(sharding=NamedSharding(
+    return aval.update(sharding=NamedSharding(
         sharding.mesh.abstract_mesh,
         sharding.spec._normalized_spec_for_aval(aval.ndim)))
   return aval
@@ -1659,8 +1659,10 @@ def physical_aval(aval):
     elt_aval = physical_element_aval(aval.dtype)
     if isinstance(aval, ShapedArray):
       from jax._src.sharding_impls import physical_sharding  # type: ignore
+      vma = aval.vma if config.varying_axes_in_types.value else frozenset()
       return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype,
-                         sharding=physical_sharding(aval, aval.sharding))
+                         sharding=physical_sharding(aval, aval.sharding),
+                         vma=vma)
     return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype)
   return aval
@@ -2019,6 +2021,8 @@ def standard_insert_pbroadcast(*args):
           if out_vma - src else arg for arg, src in zip(args, in_vma)]
 
 def standard_vma_rule(prim_name, *avals, **kwargs):
+  if not avals:
+    return avals
   vma, *vmas = [a.vma for a in avals]
   if not all(vma == vma_ for vma_ in vmas):
     raise ValueError(

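A note on the new guard in standard_vma_rule: the rule unpacks the first aval's vma with `vma, *vmas = [...]`, so a zero-input primitive (an ffi_call with no array operands, for instance) would previously die with "not enough values to unpack". A tiny repro of the structure (vma_rule_sketch is a hypothetical mirror of the rule, not the real function):

    def vma_rule_sketch(prim_name, *avals, **kwargs):
        if not avals:  # the new guard: nothing to check, nothing varies
            return avals
        vma, *vmas = [a.vma for a in avals]  # raises ValueError when avals == ()
        return vma

    vma_rule_sketch('ffi_call')  # now returns () instead of raising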
jax/_src/ffi.py

Lines changed: 6 additions & 3 deletions
@@ -24,6 +24,7 @@
 
 import jax
 from jax._src import core
+from jax._src import config
 from jax._src import deprecations
 from jax._src import dispatch
 from jax._src import effects
@@ -515,7 +516,7 @@ def wrapped(*args: ArrayLike, **kwargs: Any):
             "and an output with a different layout "
             f"{static_output_layouts[o_idx]}.")
         static_input_output_aliases += ((i_idx, o_idx),)
-
+    args = core.standard_insert_pbroadcast(*args)
     results = ffi_call_p.bind(
         *args,
         result_avals=result_avals,
@@ -638,9 +639,11 @@ def ffi_call_abstract_eval(
     has_side_effect: bool,
     **_,
 ):
-  del avals_in  # unused
+  out_vma = (core.standard_vma_rule('ffi_call', *avals_in)
+             if config.varying_axes_in_types.value else frozenset())
   effects = {_FfiEffect} if has_side_effect else core.no_effects
-  return result_avals, effects
+  return tuple(r if r is core.abstract_token else r.update(vma=out_vma)
+               for r in result_avals), effects
 
 
 def ffi_call_jvp(*args, target_name, **_):

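The net effect for users: a foreign call can now mix inputs whose vma sets differ (say, one value varying over a mesh axis and one replicated) inside shard_map, because the wrapper pbroadcasts them into agreement and the abstract eval stamps the shared set onto every non-token result. A hedged sketch of that situation, assuming a registered FFI target named "my_op" (hypothetical) and the experimental shard_map API:

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    mesh = Mesh(jax.devices(), ('i',))

    def f(x, y):
        # x varies over mesh axis 'i'; y is replicated. The new
        # standard_insert_pbroadcast call aligns their vma sets before bind.
        call = jax.ffi.ffi_call("my_op", jax.ShapeDtypeStruct(x.shape, x.dtype))
        return call(x, y)

    g = shard_map(f, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))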
jax/_src/lax/control_flow/solves.py

Lines changed: 8 additions & 6 deletions
@@ -23,6 +23,7 @@
 from jax._src import api
 from jax._src import api_util
 from jax._src import core
+from jax._src import config
 from jax._src import custom_derivatives
 from jax._src import linear_util as lu
 from jax._src.interpreters import ad
@@ -309,24 +310,25 @@ def f_aux(x):
   jaxprs = _LinearSolveTuple(
       matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)
 
-  out_flat = linear_solve_p.bind(
-      *(_flatten(all_consts) + b_flat),
-      const_lengths=const_lengths, jaxprs=jaxprs)
+  args = _flatten(all_consts) + b_flat
+  args = core.standard_insert_pbroadcast(*args)
+  out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs)
 
   return tree_unflatten(out_tree, out_flat)
 
 
 def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
   args_to_raise = args[sum(const_lengths):]
-
   # raise aux_args to shaped arrays as well if present
   # number of aux args is the difference in out_avals
   # of solve and matvec (since they map to the same vector space)
-
   num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
   if num_aux > 0:
     args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
-  return args_to_raise, jaxprs.solve.effects
+  out_vma = (core.standard_vma_rule('linear_solve', *args_to_raise)
+             if config.varying_axes_in_types.value else frozenset())
+  return (tuple(a.update(vma=out_vma) for a in args_to_raise),
+          jaxprs.solve.effects)
 
 
 def _custom_linear_solve_impl(*args, const_lengths, jaxprs):

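For reference, the user-facing entry point whose bind path changed here is lax.custom_linear_solve. A small self-contained example of the primitive itself (the pbroadcast insertion only comes into play once this runs under shard_map with vma checking enabled):

    import jax.numpy as jnp
    from jax import lax

    a = 2.0

    def matvec(x):
        return a * x  # represents A = 2I

    def solve(mv, b):
        return b / a  # exact solve for this A

    b = jnp.arange(4.0)
    x = lax.custom_linear_solve(matvec, b, solve, symmetric=True)
    # x == b / 2; the consts and b_flat handed to linear_solve_p.bind are
    # now run through core.standard_insert_pbroadcast first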
jax/_src/lax/windowed_reductions.py

Lines changed: 1 addition & 1 deletion
@@ -338,7 +338,7 @@ def _reduce_window_abstract_eval_rule(
   out_sharding = reduce_window_sharding_rule(
       operand_avals[0], window_dimensions, window_strides, padding,
       base_dilation, window_dilation)
-  out_vma = (core.standard_vma_rule('reduce_window', operand_avals)
+  out_vma = (core.standard_vma_rule('reduce_window', *operand_avals)
              if config.varying_axes_in_types.value else frozenset())
   return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding,
                            vma=out_vma)

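The one-character fix above matters because core.standard_vma_rule is variadic (def standard_vma_rule(prim_name, *avals, **kwargs)): passing the operand_avals tuple as one positional argument made the rule iterate over a single element, the tuple itself, and fail on .vma. A minimal repro with toy stand-ins (not JAX internals):

    class A:  # toy aval carrying only a vma set
        def __init__(self, vma):
            self.vma = frozenset(vma)

    def vma_rule(prim_name, *avals, **kwargs):
        vma, *vmas = [a.vma for a in avals]
        assert all(v == vma for v in vmas)
        return vma

    operand_avals = (A({'i'}), A({'i'}))
    vma_rule('reduce_window', *operand_avals)   # frozenset({'i'})
    # vma_rule('reduce_window', operand_avals)  # AttributeError: 'tuple' object has no attribute 'vma'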
jax/_src/prng.py

Lines changed: 20 additions & 9 deletions
@@ -178,7 +178,9 @@ def copy_to_host_async(self):
   def aval(self):
     logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding')
                         else None)
-    return keys_shaped_array(self._impl, self.shape, logical_sharding)
+    vma = (self._base_array.aval.vma if config.varying_axes_in_types.value else frozenset()
+           if hasattr(self._base_array, 'aval') else frozenset())
+    return keys_shaped_array(self._impl, self.shape, logical_sharding, vma)
 
   @property
   def shape(self):
@@ -329,8 +331,8 @@ def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray:
   return random_seed(seed, impl=impl)
 
 
-def keys_shaped_array(impl, shape, sharding):
-  aval = core.ShapedArray(shape, KeyTy(impl))
+def keys_shaped_array(impl, shape, sharding, vma):
+  aval = core.ShapedArray(shape, KeyTy(impl), vma=vma)
   return core.update_aval_with_sharding(aval, sharding)
 
 def base_arr_shape_to_keys_shape(impl, base_arr_shape):
@@ -550,7 +552,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray:
 
 @random_seed_p.def_abstract_eval
 def random_seed_abstract_eval(seeds_aval, *, impl):
-  return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding)
+  out_vma = seeds_aval.vma if config.varying_axes_in_types.value else frozenset()
+  return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, out_vma)
 
 @random_seed_p.def_impl
 def random_seed_impl(seeds, *, impl):
@@ -584,8 +587,9 @@ def random_split_abstract_eval(keys_aval, *, shape):
   # TODO(yashkatariya): random_split should take sharding as an arg too so we
   # don't choose None here?
   new_spec = (*keys_aval.sharding.spec, *[None] * len(shape))
+  out_vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset()
   return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape),
-                           keys_aval.sharding.with_spec(new_spec))
+                           keys_aval.sharding.with_spec(new_spec), out_vma)
 
 @random_split_p.def_impl
 def random_split_impl(keys, *, shape):
@@ -611,7 +615,9 @@ def random_split_lowering(ctx, keys, *, shape):
 
 
 def random_fold_in(keys, msgs):
-  return random_fold_in_p.bind(keys, jnp.asarray(msgs))
+  msgs = jnp.asarray(msgs)
+  keys, msgs = core.standard_insert_pbroadcast(keys, msgs)
+  return random_fold_in_p.bind(keys, msgs)
 
 random_fold_in_p = core.Primitive('random_fold_in')
 ad.defjvp_zero(random_fold_in_p)
@@ -623,7 +629,9 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
       'random_fold_in', keys_aval, msgs_aval)
   sharding = lax_internal.broadcasting_sharding_rule(
       'random_fold_in', keys_aval, msgs_aval)
-  return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding)
+  vma = (core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval)
+         if config.varying_axes_in_types.value else frozenset())
+  return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding, vma=vma)
 
 @random_fold_in_p.def_impl
 def random_fold_in_impl(keys, msgs):
@@ -661,7 +669,8 @@ def random_bits(keys, bit_width, shape):
 def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
   out_shape = (*keys_aval.shape, *shape)
   out_dtype = dtypes.dtype(f'uint{bit_width}')
-  return core.ShapedArray(out_shape, out_dtype)
+  vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset()
+  return core.ShapedArray(out_shape, out_dtype, vma=vma)
 
 @random_bits_p.def_impl
 def random_bits_impl(keys, *, bit_width, shape):
@@ -718,7 +727,9 @@ def random_wrap(base_arr, *, impl):
 def random_wrap_abstract_eval(base_arr_aval, *, impl):
   shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape)
   sharding = logical_sharding(shape, KeyTy(impl), base_arr_aval.sharding)
-  return keys_shaped_array(impl, shape, sharding)
+  out_vma = (base_arr_aval.vma if config.varying_axes_in_types.value else
+             frozenset())
+  return keys_shaped_array(impl, shape, sharding, out_vma)
 
 @random_wrap_p.def_impl
 def random_wrap_impl(base_arr, *, impl):
def random_wrap_impl(base_arr, *, impl):

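Taken together, the PRNG changes make key arrays first-class citizens of the vma system: keys_shaped_array now threads a vma set into the key aval, and the seed, split, fold_in, bits, and wrap primitives all propagate or combine it. A hedged end-to-end sketch, assuming vma checking (config.varying_axes_in_types) is enabled and using the experimental shard_map API:

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    mesh = Mesh(jax.devices(), ('i',))

    def f(seeds):
        key = jax.random.key(0)                  # replicated key: empty vma
        key = jax.random.fold_in(key, seeds[0])  # msgs vary over 'i'; the new
                                                 # pbroadcast aligns the key
        return jax.random.bits(key, (4,))        # random_bits inherits the
                                                 # key's vma for its output

    g = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    out = g(jnp.arange(jax.device_count(), dtype=jnp.uint32))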