3 changes: 3 additions & 0 deletions jax/_src/ad_util.py
@@ -73,7 +73,10 @@ def __repr__(self) -> str:
    return f'Zero({self.aval})'

  @staticmethod
  def from_primal_value(val: Any) -> Zero:
    # TODO(mattjj,yashkatariya): sometimes we want to_cotangent_aval...
    return Zero(get_aval(val).to_tangent_aval())

  def instantiate(self):
    return zeros_like_aval(self.aval)

register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))

2 changes: 2 additions & 0 deletions jax/_src/api.py
@@ -2288,6 +2288,8 @@ def vjp3(f, *primals, has_aux=False):

def _vjp3(fun, *primals, has_aux=False):
  primals_flat, in_tree = tree_flatten(primals)
  primals_flat = [dtypes.canonicalize_value(v) if not isinstance(v, core.Tracer)
                  else v for v in primals_flat]
  for arg in primals_flat: dispatch.check_arg(arg)
  if not has_aux:
    flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
1 change: 0 additions & 1 deletion jax/_src/core.py
@@ -2622,7 +2622,6 @@ class ArrayRefImpl:
  def __init__(self, aval, buf):
    from jax._src.state.types import AbstractRef  # pytype: disable=import-error
    assert isinstance(aval, AbstractRef) and isinstance(aval.inner_aval, ShapedArray)
    assert isinstance(buf, Array)
    self._aval = aval
    self._buf = buf

14 changes: 7 additions & 7 deletions jax/_src/interpreters/ad.py
@@ -528,13 +528,13 @@ def read(x: core.Atom) -> Array | GradAccum:
      acc.accum(ct)  # jaxpr.outvars can have Literals, env can have inst zeros
  with ctx:
    for eqn in lin_eqns[::-1]:
      if eqn.primitive.ref_primitive:
        ct = env.pop(eqn.outvars[0]).freeze()
        acc = read(eqn.invars[0])
        if isinstance(acc, GradAccum):
          acc.accum(ct)
      else:
        with eqn.ctx.manager, _name_stack_ctx(eqn.source_info):
      with eqn.ctx.manager, _name_stack_ctx(eqn.source_info):
        if eqn.primitive.ref_primitive:
          ct = env.pop(eqn.outvars[0]).freeze()
          acc = read(eqn.invars[0])
          if isinstance(acc, GradAccum):
            acc.accum(ct)
        else:
          cts_in = [env.pop(v).freeze() for v in eqn.outvars]
          if not eqn.primitive.multiple_results:
            cts_in, = cts_in
4 changes: 4 additions & 0 deletions jax/_src/interpreters/batching.py
@@ -1221,6 +1221,10 @@ def broadcast(x, sz, axis, mesh_axis):
    x = core.pvary(x, tuple(spmd_names))
  return x

def matchaxis2(axis_data, src, dst, x, sum_match=False):
  return matchaxis(axis_data.name, axis_data.size, axis_data.explicit_mesh_axis,
                   src, dst, x, sum_match)

def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
  if dst == jumble_axis:
    x = bdim_at_front(x, src, sz)
3 changes: 2 additions & 1 deletion jax/_src/interpreters/pxla.py
@@ -1395,7 +1395,8 @@ def __call__(self, *args):
    out_ = []
    for i, o in zip(self.mut.out_mut, out):
      if i is not None:
        args[i]._refs._buf._replace_with(o)  # type: ignore
        try: args[i]._refs._buf._replace_with(o)  # type: ignore
        except AttributeError: pass  # TODO(mattjj): remove float0
Review comment on lines +1398 to +1399 (severity: medium):

Using a broad `except AttributeError: pass` can be risky, as it might silence unexpected errors. While the TODO comment provides context about float0, it would be safer to make the check more specific if possible. For instance, you could check whether the argument is of a type known not to have the `_refs` attribute, such as float0 arrays. If a specific check is not feasible, consider adding a log warning inside the except block to aid in debugging any future, unrelated AttributeErrors that might otherwise be suppressed here.
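As a rough illustration of that suggestion, not something this PR implements: a narrower version might special-case float0 and log anything else, assuming float0-dtyped arguments are the only ones expected to lack `_refs`. The helper name and logging setup below are hypothetical; only `_refs._buf._replace_with` comes from the diff.

```python
# Hypothetical sketch of the narrower check suggested in the review comment.
import logging

from jax._src import dtypes

logger = logging.getLogger(__name__)

def _maybe_replace_ref_buf(arg, new_buf):
  # float0 gradients carry no data, so there is no buffer to write back.
  if getattr(arg, 'dtype', None) == dtypes.float0:
    return
  refs = getattr(arg, '_refs', None)
  if refs is None:
    # Surface anything else that unexpectedly lacks `_refs` instead of
    # silently dropping the in-place update.
    logger.warning("Skipping ref update for %s: no `_refs` attribute.", type(arg))
    return
  refs._buf._replace_with(new_buf)
```

Whether a dtype check is the right predicate depends on how float0 values reach this path; the sketch is only meant to make the reviewer's point concrete.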

      else:
        out_.append(o)
    return out_
1 change: 1 addition & 0 deletions jax/_src/pjit.py
@@ -2968,6 +2968,7 @@ def reshard(xs, out_shardings):
reshard_p.skip_canonicalization = True

def _reshard_abstract_eval(aval, dst_sharding):
  assert isinstance(aval, core.ShapedArray)
  if aval.sharding == dst_sharding:
    return aval
  return aval.update(sharding=dst_sharding)
21 changes: 16 additions & 5 deletions jax/_src/state/primitives.py
@@ -656,7 +656,7 @@ def _swap_transpose_fancy(g, ref_, x, *idx, **params):
ad.fancy_transposes[swap_p] = _swap_transpose_fancy

def addupdate_transpose_fancy(cts_in, ref_, x, *idx, **params):
  if ref_.ref is not None:
  if ref_.ref is not None and isinstance(x, ad.GradAccum):
    x_bar = get_p.bind(ref_.ref, *idx, **params)
    x.accum(x_bar)
ad.fancy_transposes[addupdate_p] = addupdate_transpose_fancy
@@ -706,7 +706,19 @@ def _state_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  return eqn, eqn, [False], [True], res  # full remat
pe.partial_eval_jaxpr_custom_rules[get_p] = _state_partial_eval_custom
pe.partial_eval_jaxpr_custom_rules[swap_p] = _state_partial_eval_custom
pe.partial_eval_jaxpr_custom_rules[addupdate_p] = _state_partial_eval_custom

def _addupdate_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  del saveable  # ignored, always full remat state ops on known inputs
  ref_unk, *_ = unks_in
  ref_inst, *inst_in = inst_in
  _, *val_vars = eqn.invars
  assert ref_inst
  res = [v for v, inst in zip(val_vars, inst_in) if not inst]
  if ref_unk:
    return None, eqn, [], [], res  # tangent operation
  else:
    return eqn, eqn, [], [], res  # full remat
pe.partial_eval_jaxpr_custom_rules[addupdate_p] = _addupdate_partial_eval_custom

## get/swap/addupdate batching rules

@@ -972,9 +984,8 @@ def _addupdate_vmap(axis_data, batched_args, batched_dims, *, tree):
"Move the array reference to be an argument to the vmapped "
"function?")
if not indexers:
if ref_is_batched and not val_is_batched:
val = batching.broadcast(val, axis_data.size, ref_dim,
axis_data.explicit_mesh_axis)
if val_dim != ref_dim:
val = batching.matchaxis2(axis_data, val_dim, ref_dim, val)
return addupdate_p.bind(ref, val, *flat_idxs, tree=tree), []
if len(indexers) > 1:
raise NotImplementedError("Batching with multiple indexers not supported.")
3 changes: 2 additions & 1 deletion tests/api_test.py
@@ -3129,10 +3129,11 @@ def test_float0_reshape(self):
  def test_float0_error(self):
    # float0 is incompatible with other dtypes
    float0_array = jax.grad(lambda x: x+0., allow_int=True)(1)
    self.assertEqual(float0_array.dtype, dtypes.float0)
    error_text = "float0s do not support any operations by design"

    with self.assertRaisesRegex(TypeError, error_text):
      # dispatch via Array
      # dispatch via Array.__add__ and hence jax.numpy
      _ = float0_array + jnp.zeros(())

    with self.assertRaisesRegex(TypeError, error_text):
16 changes: 8 additions & 8 deletions tests/pjit_test.py
@@ -5144,14 +5144,14 @@ def f(x, y):
    self.assertEqual(out[1].sharding, arr2.sharding)

    jaxpr = jitted_grad.trace(arr1, arr2).jaxpr
    bwd_jaxpr = jaxpr.eqns[-1]
    expected_spec = [('broadcast_in_dim', P('x', None)),
                     ('dot_general', P('x', None)),
                     ('transpose', P(None, 'x')),
                     ('dot_general', P('x', None))]
    for eqn, spec in zip(bwd_jaxpr.params['jaxpr'].eqns, expected_spec):
      self.assertEqual(eqn.primitive.name, spec[0])
      self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec[1])
    bwd_jaxpr = next(e for e in reversed(jaxpr.eqns) if 'jaxpr' in e.params)
    expected_spec = {'broadcast_in_dim': P('x', None),
                     'dot_general': P('x', None),
                     'transpose': P(None, 'x')}
    for eqn in bwd_jaxpr.params['jaxpr'].eqns:
      spec = expected_spec.get(eqn.primitive.name)
      if spec is not None:
        self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec)

  @parameterized.named_parameters(
      ('fail1', P('x', None), P(None, 'x'),