|
49 | 49 | argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, |
50 | 50 | donation_vector, shaped_abstractify, check_callable, resolve_argnums, |
51 | 51 | argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, |
52 | | - hoist_obj_attrs) |
| 52 | + hoist_obj_attrs, _check_no_aliased_ref_args, |
| 53 | + _check_no_aliased_closed_over_refs) |
53 | 54 | from jax._src.interpreters import partial_eval as pe |
54 | 55 | from jax._src.partition_spec import PartitionSpec |
55 | 56 | from jax._src.interpreters import xla |
@@ -627,7 +628,8 @@ def _infer_params_impl( |
627 | 628 | flat_fun, in_type, attr_token, dbg, |
628 | 629 | HashableFunction(res_paths, closure=()), |
629 | 630 | IgnoreKey(ji.inline)) |
630 | | - _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) |
| 631 | + if config.mutable_array_checks.value: |
| 632 | + _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) |
631 | 633 | _attr_update(flat_fun, in_type, attr_token, attrs_tracked) |
632 | 634 |
|
633 | 635 | out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( |
@@ -764,33 +766,9 @@ def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...] |
764 | 766 | " static_argnums or static_argnames parameters of jax.jit." |
765 | 767 | ) from None |
766 | 768 | if config.mutable_array_checks.value: |
767 | | - # TODO(mattjj): make this faster |
768 | | - refs: dict[int, int] = {} |
769 | | - for i, (a, x) in enumerate(zip(avals, explicit_args)): |
770 | | - if (isinstance(a, AbstractRef) and |
771 | | - (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): |
772 | | - raise ValueError( |
773 | | - "only one reference to a mutable array may be passed as an argument " |
774 | | - f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} " |
775 | | - f"the mutable array reference of type {a.str_short()} appeared at both " |
776 | | - f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}." |
777 | | - if dbg else |
778 | | - f"at both flat index {dup_idx} and flat index {i}") from None |
| 769 | + _check_no_aliased_ref_args(dbg, avals, explicit_args) |
779 | 770 | return tuple(avals) |
780 | 771 |
|
781 | | -def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None: |
782 | | - if not config.mutable_array_checks.value: return |
783 | | - refs: set[int] = {id(core.get_referent(c)) for c in consts |
784 | | - if isinstance(core.get_aval(c), AbstractRef)} |
785 | | - for i, x in enumerate(args): |
786 | | - if id(core.get_referent(x)) in refs: |
787 | | - a = shaped_abstractify(x) |
788 | | - raise ValueError( |
789 | | - f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable " |
790 | | - f"array reference of type {a.str_short()} was both closed over and " |
791 | | - f"passed as the argument " |
792 | | - f"{dbg.arg_names[i]}" if dbg else "at flat index {i}") |
793 | | - |
794 | 772 | def _extract_implicit_args( |
795 | 773 | in_type: Sequence[tuple[core.AbstractValue, bool]], |
796 | 774 | explicit_args: Sequence[Any] |
|
0 commit comments