diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py
index 41ff990860a1..4b693fb1a890 100644
--- a/jax/_src/ad_util.py
+++ b/jax/_src/ad_util.py
@@ -73,7 +73,10 @@ def __repr__(self) -> str:
     return f'Zero({self.aval})'
   @staticmethod
   def from_primal_value(val: Any) -> Zero:
+    # TODO(mattjj,yashkatariya): sometimes we want to_cotangent_aval...
     return Zero(get_aval(val).to_tangent_aval())
+  def instantiate(self):
+    return zeros_like_aval(self.aval)
 
 register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))
 
diff --git a/jax/_src/api.py b/jax/_src/api.py
index f9bfe0fc264a..6fd506e915a4 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -2288,6 +2288,8 @@ def vjp3(f, *primals, has_aux=False):
 
 def _vjp3(fun, *primals, has_aux=False):
   primals_flat, in_tree = tree_flatten(primals)
+  primals_flat = [dtypes.canonicalize_value(v) if not isinstance(v, core.Tracer)
+                  else v for v in primals_flat]
   for arg in primals_flat: dispatch.check_arg(arg)
   if not has_aux:
     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
diff --git a/jax/_src/core.py b/jax/_src/core.py
index c81dd820b913..44cf4d249ed9 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -2622,7 +2622,6 @@ class ArrayRefImpl:
   def __init__(self, aval, buf):
     from jax._src.state.types import AbstractRef  # pytype: disable=import-error
     assert isinstance(aval, AbstractRef) and isinstance(aval.inner_aval, ShapedArray)
-    assert isinstance(buf, Array)
     self._aval = aval
     self._buf = buf
 
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index 13d95d9df14f..5e22eca6f494 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -528,13 +528,13 @@ def read(x: core.Atom) -> Array | GradAccum:
       acc.accum(ct)  # jaxpr.outvars can have Literals, env can have inst zeros
   with ctx:
     for eqn in lin_eqns[::-1]:
-      if eqn.primitive.ref_primitive:
-        ct = env.pop(eqn.outvars[0]).freeze()
-        acc = read(eqn.invars[0])
-        if isinstance(acc, GradAccum):
-          acc.accum(ct)
-      else:
-        with eqn.ctx.manager, _name_stack_ctx(eqn.source_info):
+      with eqn.ctx.manager, _name_stack_ctx(eqn.source_info):
+        if eqn.primitive.ref_primitive:
+          ct = env.pop(eqn.outvars[0]).freeze()
+          acc = read(eqn.invars[0])
+          if isinstance(acc, GradAccum):
+            acc.accum(ct)
+        else:
           cts_in = [env.pop(v).freeze() for v in eqn.outvars]
           if not eqn.primitive.multiple_results:
             cts_in, = cts_in
diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py
index db97ef0dbfdc..d6613bad0647 100644
--- a/jax/_src/interpreters/batching.py
+++ b/jax/_src/interpreters/batching.py
@@ -1221,6 +1221,10 @@ def broadcast(x, sz, axis, mesh_axis):
     x = core.pvary(x, tuple(spmd_names))
   return x
 
+def matchaxis2(axis_data, src, dst, x, sum_match=False):
+  return matchaxis(axis_data.name, axis_data.size, axis_data.explicit_mesh_axis,
+                   src, dst, x, sum_match)
+
 def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
   if dst == jumble_axis:
     x = bdim_at_front(x, src, sz)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 9fbe85c2060b..de54bea12950 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1395,7 +1395,8 @@ def __call__(self, *args):
       out_ = []
       for i, o in zip(self.mut.out_mut, out):
         if i is not None:
-          args[i]._refs._buf._replace_with(o)  # type: ignore
+          try: args[i]._refs._buf._replace_with(o)  # type: ignore
+          except AttributeError: pass  # TODO(mattjj): remove float0
         else:
           out_.append(o)
       return out_
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index b6d113c30dec..ed568d85cd7e 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2968,6 +2968,7 @@ def reshard(xs, out_shardings):
 reshard_p.skip_canonicalization = True
 
 def _reshard_abstract_eval(aval, dst_sharding):
+  assert isinstance(aval, core.ShapedArray)
   if aval.sharding == dst_sharding:
     return aval
   return aval.update(sharding=dst_sharding)
diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py
index fcb69e140788..c50979bc7d53 100644
--- a/jax/_src/state/primitives.py
+++ b/jax/_src/state/primitives.py
@@ -656,7 +656,7 @@ def _swap_transpose_fancy(g, ref_, x, *idx, **params):
 ad.fancy_transposes[swap_p] = _swap_transpose_fancy
 
 def addupdate_transpose_fancy(cts_in, ref_, x, *idx, **params):
-  if ref_.ref is not None:
+  if ref_.ref is not None and isinstance(x, ad.GradAccum):
     x_bar = get_p.bind(ref_.ref, *idx, **params)
     x.accum(x_bar)
 ad.fancy_transposes[addupdate_p] = addupdate_transpose_fancy
@@ -706,7 +706,19 @@ def _state_partial_eval_custom(saveable, unks_in, inst_in, eqn):
     return eqn, eqn, [False], [True], res  # full remat
 pe.partial_eval_jaxpr_custom_rules[get_p] = _state_partial_eval_custom
 pe.partial_eval_jaxpr_custom_rules[swap_p] = _state_partial_eval_custom
-pe.partial_eval_jaxpr_custom_rules[addupdate_p] = _state_partial_eval_custom
+
+def _addupdate_partial_eval_custom(saveable, unks_in, inst_in, eqn):
+  del saveable  # ignored, always full remat state ops on known inputs
+  ref_unk, *_ = unks_in
+  ref_inst, *inst_in = inst_in
+  _, *val_vars = eqn.invars
+  assert ref_inst
+  res = [v for v, inst in zip(val_vars, inst_in) if not inst]
+  if ref_unk:
+    return None, eqn, [], [], res  # tangent operation
+  else:
+    return eqn, eqn, [], [], res  # full remat
+pe.partial_eval_jaxpr_custom_rules[addupdate_p] = _addupdate_partial_eval_custom
 
 
 ## get/swap/addupdate batching rules
@@ -972,9 +984,8 @@ def _addupdate_vmap(axis_data, batched_args, batched_dims, *, tree):
         "Move the array reference to be an argument to the vmapped "
         "function?")
   if not indexers:
-    if ref_is_batched and not val_is_batched:
-      val = batching.broadcast(val, axis_data.size, ref_dim,
-                               axis_data.explicit_mesh_axis)
+    if val_dim != ref_dim:
+      val = batching.matchaxis2(axis_data, val_dim, ref_dim, val)
     return addupdate_p.bind(ref, val, *flat_idxs, tree=tree), []
   if len(indexers) > 1:
     raise NotImplementedError("Batching with multiple indexers not supported.")
diff --git a/tests/api_test.py b/tests/api_test.py
index c1a22fc7695f..e3de4246e06c 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -3129,10 +3129,11 @@ def test_float0_reshape(self):
   def test_float0_error(self):
     # float0 is incompatible with other dtypes
     float0_array = jax.grad(lambda x: x+0., allow_int=True)(1)
+    self.assertEqual(float0_array.dtype, dtypes.float0)
     error_text = "float0s do not support any operations by design"
 
     with self.assertRaisesRegex(TypeError, error_text):
-      # dispatch via Array
+      # dispatch via Array.__add__ and hence jax.numpy
       _ = float0_array + jnp.zeros(())
 
     with self.assertRaisesRegex(TypeError, error_text):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 66be4053b17e..c8de63d80f16 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -5144,14 +5144,14 @@ def f(x, y):
     self.assertEqual(out[1].sharding, arr2.sharding)
 
     jaxpr = jitted_grad.trace(arr1, arr2).jaxpr
-    bwd_jaxpr = jaxpr.eqns[-1]
-    expected_spec = [('broadcast_in_dim', P('x', None)),
-                     ('dot_general', P('x', None)),
-                     ('transpose', P(None, 'x')),
-                     ('dot_general', P('x', None))]
-    for eqn, spec in zip(bwd_jaxpr.params['jaxpr'].eqns, expected_spec):
-      self.assertEqual(eqn.primitive.name, spec[0])
-      self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec[1])
+    bwd_jaxpr = next(e for e in reversed(jaxpr.eqns) if 'jaxpr' in e.params)
+    expected_spec = {'broadcast_in_dim': P('x', None),
+                     'dot_general': P('x', None),
+                     'transpose': P(None, 'x')}
+    for eqn in bwd_jaxpr.params['jaxpr'].eqns:
+      spec = expected_spec.get(eqn.primitive.name)
+      if spec is not None:
+        self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec)
 
   @parameterized.named_parameters(
     ('fail1', P('x', None), P(None, 'x'),