Skip to content

Commit e528562

Browse files
committed
add mutable array ref error checks to scan
1 parent 74eca13 commit e528562

File tree

5 files changed

+59
-32
lines changed

5 files changed

+59
-32
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ pytype_strict_library(
344344
":traceback_util",
345345
":tree_util",
346346
":util",
347+
":state_types",
347348
] + py_deps("numpy"),
348349
)
349350

jax/_src/api_util.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
import numpy as np
2424

2525
from jax._src import core
26+
from jax._src import config
2627
from jax._src import dtypes
28+
from jax._src.state.types import AbstractRef
2729
from jax._src.abstract_arrays import numpy_scalar_types
2830
from jax._src.core import ShapedArray
2931
from jax._src.tree_util import (
@@ -737,3 +739,31 @@ def __eq__(self, other):
737739
def register_class_with_attrs(t: type) -> None:
738740
_class_with_attrs.add(t)
739741
_class_with_attrs: set[type] = set()
742+
743+
# TODO(mattjj): make this function faster
744+
def _check_no_aliased_ref_args(dbg, avals, args):
745+
assert config.mutable_array_checks.value
746+
refs: dict[int, int] = {}
747+
for i, (a, x) in enumerate(zip(avals, args)):
748+
if (isinstance(a, AbstractRef) and
749+
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
750+
raise ValueError(
751+
"only one reference to a mutable array may be passed as an argument "
752+
f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} "
753+
f"the mutable array reference of type {a.str_short()} appeared at both "
754+
f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}."
755+
if dbg else
756+
f"at both flat index {dup_idx} and flat index {i}") from None
757+
758+
def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
759+
assert config.mutable_array_checks.value
760+
refs: set[int] = {id(core.get_referent(c)) for c in consts
761+
if isinstance(core.get_aval(c), AbstractRef)}
762+
for i, x in enumerate(args):
763+
if id(core.get_referent(x)) in refs:
764+
a = shaped_abstractify(x)
765+
raise ValueError(
766+
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
767+
f"array reference of type {a.str_short()} was both closed over and "
768+
f"passed as the argument "
769+
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")

jax/_src/lax/control_flow/loops.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
from jax._src import source_info_util
3535
from jax._src import state
3636
from jax._src import util
37-
from jax._src.api_util import shaped_abstractify
37+
from jax._src.api_util import (
38+
shaped_abstractify, _check_no_aliased_ref_args,
39+
_check_no_aliased_closed_over_refs)
3840
from jax._src.core import ShapedArray
3941
from jax._src.interpreters import ad
4042
from jax._src.interpreters import batching
@@ -271,13 +273,20 @@ def scan(f, init, xs, length=None):
271273
xs_avals = [core.get_aval(x) for x in xs_flat]
272274
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
273275

276+
if config.mutable_array_checks.value:
277+
in_flat, in_tree = tree_flatten((init, xs))
278+
dbg = pe.debug_info(f, in_tree, None, False, 'scan')
279+
in_avals = tuple(_map(core.get_aval, in_flat))
280+
_check_no_aliased_ref_args(dbg, in_avals, in_flat)
281+
274282
def _create_jaxpr(init):
275283
init_flat, init_tree = tree_flatten(init)
276284
in_flat, in_tree = tree_flatten((init, xs))
277-
278285
carry_avals = tuple(_map(core.get_aval, init_flat))
279286
jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
280287
f, in_tree, (*carry_avals, *x_avals), "scan")
288+
if config.mutable_array_checks.value:
289+
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), in_flat)
281290
out_tree_children = out_tree.children()
282291
if len(out_tree_children) != 2:
283292
msg = "scan body output must be a pair, got {}."

jax/_src/pjit.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@
4949
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
5050
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
5151
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
52-
hoist_obj_attrs)
52+
hoist_obj_attrs, _check_no_aliased_ref_args,
53+
_check_no_aliased_closed_over_refs)
5354
from jax._src.interpreters import partial_eval as pe
5455
from jax._src.partition_spec import PartitionSpec
5556
from jax._src.interpreters import xla
@@ -627,7 +628,8 @@ def _infer_params_impl(
627628
flat_fun, in_type, attr_token, dbg,
628629
HashableFunction(res_paths, closure=()),
629630
IgnoreKey(ji.inline))
630-
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
631+
if config.mutable_array_checks.value:
632+
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
631633
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)
632634

633635
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
@@ -764,33 +766,9 @@ def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]
764766
" static_argnums or static_argnames parameters of jax.jit."
765767
) from None
766768
if config.mutable_array_checks.value:
767-
# TODO(mattjj): make this faster
768-
refs: dict[int, int] = {}
769-
for i, (a, x) in enumerate(zip(avals, explicit_args)):
770-
if (isinstance(a, AbstractRef) and
771-
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
772-
raise ValueError(
773-
"only one reference to a mutable array may be passed as an argument "
774-
f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} "
775-
f"the mutable array reference of type {a.str_short()} appeared at both "
776-
f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}."
777-
if dbg else
778-
f"at both flat index {dup_idx} and flat index {i}") from None
769+
_check_no_aliased_ref_args(dbg, avals, explicit_args)
779770
return tuple(avals)
780771

781-
def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
782-
if not config.mutable_array_checks.value: return
783-
refs: set[int] = {id(core.get_referent(c)) for c in consts
784-
if isinstance(core.get_aval(c), AbstractRef)}
785-
for i, x in enumerate(args):
786-
if id(core.get_referent(x)) in refs:
787-
a = shaped_abstractify(x)
788-
raise ValueError(
789-
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
790-
f"array reference of type {a.str_short()} was both closed over and "
791-
f"passed as the argument "
792-
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
793-
794772
def _extract_implicit_args(
795773
in_type: Sequence[tuple[core.AbstractValue, bool]],
796774
explicit_args: Sequence[Any]

tests/mutable_array_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ def test_scan_scanned_mut_array(self, jit):
206206
def body_fun(_, index_x):
207207
(index, x) = index_x
208208
x[...] += index
209-
# breakpoint()
210209
return ((), x[...])
211210

212211
x_mut = core.mutable_array(np.arange(5))
@@ -289,8 +288,18 @@ def test_return_from_scan(self):
289288
ValueError, "traced for scan returned a mutable array reference of type"):
290289
jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3))
291290

292-
# TODO test_argument_aliases_scan
293-
# TODO test_closure_and_argument_aliases_scan
291+
def test_argument_aliases_scan(self):
292+
x_ref = core.mutable_array(0.)
293+
with self.assertRaisesRegex(
294+
ValueError, r"appeared at both c\[0\] and c\[1\]"):
295+
jax.lax.scan(lambda c, _: (None, None), (x_ref, x_ref), None, length=1)
296+
297+
def test_closure_and_argument_aliases_scan(self):
298+
x_ref = core.mutable_array(0.)
299+
with self.assertRaisesRegex(
300+
ValueError, r"closed over and passed as the argument y_ref"):
301+
jax.lax.scan(lambda y_ref, _: (x_ref[...] + y_ref[...], None), x_ref,
302+
None, length=1)
294303

295304
def test_return_from_cond(self):
296305
with self.assertRaisesRegex(

0 commit comments

Comments
 (0)