@@ -556,17 +556,14 @@ def _infer_params_impl(
556556 "pjit does not support kwargs when in_shardings is specified." )
557557
558558 if pjit_mesh is not None :
559- jit_name = 'pjit'
560559 if (ji .backend or ji .device ) and not pjit_mesh .empty :
561560 raise ValueError (
562561 "Mesh context manager should not be used with jit when backend or "
563562 "device is also specified as an argument to jit." )
564- else :
565- jit_name = 'jit'
566563
567564 axes_specs = _flat_axes_specs (ji .abstracted_axes , * args , ** kwargs )
568565
569- dbg = debug_info (jit_name , ji .fun_sourceinfo , ji .fun_signature , args , kwargs ,
566+ dbg = debug_info ('jit' , ji .fun_sourceinfo , ji .fun_signature , args , kwargs ,
570567 ji .static_argnums , ji .static_argnames )
571568 f = lu .wrap_init (fun )
572569 f , res_paths = result_paths (f )
@@ -593,6 +590,7 @@ def _infer_params_impl(
593590 in_shardings_leaves = out_shardings_leaves = tuple (leaves )
594591 in_shardings_treedef = out_shardings_treedef = treedef
595592 else :
593+ jit_name = 'pjit' if pjit_mesh is not None else 'jit'
596594 in_shardings_leaves = tuple (
597595 _create_sharding_for_array (pjit_mesh , x , 'in_shardings' , jit_name )
598596 for x in ji .in_shardings_leaves )
@@ -607,35 +605,12 @@ def _infer_params_impl(
607605
608606 in_type : core .InputType | tuple [core .AbstractValue , ...]
609607 if config .dynamic_shapes .value :
608+ assert in_avals is None
610609 in_type = pe .infer_lambda_input_type (axes_specs , explicit_args )
611610 in_avals = tuple (a for a , e in in_type if e )
612- elif in_avals is None :
613- avals = []
614- for i , a in enumerate (explicit_args ):
615- try :
616- avals .append (shaped_abstractify (a ))
617- except OverflowError as e :
618- arg_path = (f"argument path is { dbg .arg_names [i ]} " if dbg
619- else f"flattened argument number is { i } " )
620- raise OverflowError (
621- "An overflow was encountered while parsing an argument to a jitted "
622- f"computation, whose { arg_path } ."
623- ) from e
624- except TypeError as e :
625- arg_description = (f"path { dbg .arg_names [i ]} " if dbg
626- else f"flattened argument number { i } " )
627- raise TypeError (
628- f"Error interpreting argument to { fun } as an abstract array."
629- f" The problematic value is of type { type (a )} and was passed to"
630- f" the function at { arg_description } .\n "
631- "This typically means that a jit-wrapped function was called with a non-array"
632- " argument, and this argument was not marked as static using the"
633- " static_argnums or static_argnames parameters of jax.jit."
634- ) from e
635-
636- in_type = in_avals = tuple (avals )
637611 else :
638- in_type = in_avals
612+ in_type = in_avals # type: ignore
613+ assert in_avals is not None
639614
640615 in_shardings_flat , in_layouts_flat = _process_in_axis_resources (
641616 in_shardings_treedef , in_shardings_leaves ,
@@ -652,6 +627,7 @@ def _infer_params_impl(
652627 flat_fun , in_type , attr_token , dbg ,
653628 HashableFunction (res_paths , closure = ()),
654629 IgnoreKey (ji .inline ))
630+ _check_no_aliased_closed_over_refs (dbg , (* jaxpr .consts , * consts ), explicit_args )
655631 _attr_update (flat_fun , in_type , attr_token , attrs_tracked )
656632
657633 out_shardings_flat , out_layouts_flat = _check_and_canonicalize_out_shardings (
@@ -693,7 +669,6 @@ def _infer_params_impl(
693669 donated_invars , dbg .arg_names if dbg else None , len (consts ),
694670 attrs_tracked , abstract_mesh ), args_flat
695671
696-
697672def get_abstract_mesh_from_avals (in_avals ):
698673 if not config .sharding_in_types .value :
699674 return None
@@ -711,9 +686,7 @@ def get_abstract_mesh_from_avals(in_avals):
711686class InferParamsCacheEntry :
712687 """Mutable value object for _infer_params_cached."""
713688 __slots__ = ['pjit_params' ]
714-
715689 pjit_params : PjitParams | None
716-
717690 def __init__ (self ):
718691 self .pjit_params = None
719692
@@ -747,34 +720,76 @@ def _infer_params(
747720 resource_env = None
748721 pjit_mesh = None
749722
750- skip_cache = config .dynamic_shapes .value
751- if not skip_cache :
752- signature , dynargs = jax_jit .parse_arguments (
753- args , tuple (kwargs .values ()), tuple (kwargs .keys ()), ji .static_argnums ,
754- ji .static_argnames , tree_util .default_registry )
755- try :
756- avals = tuple (shaped_abstractify (a ) for a in dynargs )
757- except (OverflowError , TypeError ):
758- # If we see something we don't understand, use the slow path.
759- skip_cache = True
760-
761- if skip_cache :
723+ if config .dynamic_shapes .value : # if dynamic shapes, don't use the cache
762724 p , args_flat = _infer_params_impl (fun , ji , pjit_mesh , resource_env , args ,
763725 kwargs , in_avals = None )
764726 return p , p .consts + args_flat
765727
766- entry = _infer_params_cached (
767- fun , ji , signature , avals , pjit_mesh , resource_env )
728+ signature , dynargs = jax_jit .parse_arguments (
729+ args , tuple (kwargs .values ()), tuple (kwargs .keys ()), ji .static_argnums ,
730+ ji .static_argnames , tree_util .default_registry )
731+ dbg = debug_info ('jit' , ji .fun_sourceinfo , ji .fun_signature , args , kwargs ,
732+ ji .static_argnums , ji .static_argnames )
733+ avals = _infer_input_type (fun , dbg , dynargs )
734+ entry = _infer_params_cached (fun , ji , signature , avals , pjit_mesh , resource_env )
768735 if entry .pjit_params is None :
769736 p , args_flat = _infer_params_impl (
770737 fun , ji , pjit_mesh , resource_env , args , kwargs , in_avals = avals )
771- if p .attrs_tracked :
772- # If there are attrs_tracked, don't use the cache.
738+ if p .attrs_tracked : # if attrs, don't popoulate the cache
773739 return p , p .consts + args_flat
774- else :
775- entry .pjit_params = p
740+ entry .pjit_params = p
776741 return entry .pjit_params , entry .pjit_params .consts + dynargs
777742
743+ def _infer_input_type (fun , dbg , explicit_args ) -> tuple [core .AbstractValue , ...]:
744+ avals = []
745+ try :
746+ for i , x in enumerate (explicit_args ):
747+ avals .append (shaped_abstractify (x ))
748+ except OverflowError :
749+ arg_path = (f"argument path is { dbg .arg_names [i ]} " if dbg # type: ignore
750+ else f"flattened argument number is { i } " ) # type: ignore
751+ raise OverflowError (
752+ "An overflow was encountered while parsing an argument to a jitted "
753+ f"computation, whose { arg_path } ."
754+ ) from None
755+ except TypeError :
756+ arg_description = (f"path { dbg .arg_names [i ]} " if dbg # type: ignore
757+ else f"flattened argument number { i } " ) # type: ignore
758+ raise TypeError (
759+ f"Error interpreting argument to { fun } as an abstract array."
760+ f" The problematic value is of type { type (x )} and was passed to" # type: ignore
761+ f" the function at { arg_description } .\n "
762+ "This typically means that a jit-wrapped function was called with a non-array"
763+ " argument, and this argument was not marked as static using the"
764+ " static_argnums or static_argnames parameters of jax.jit."
765+ ) from None
766+ if config .mutable_array_checks .value :
767+ # TODO(mattjj): make this faster
768+ refs : dict [int , int ] = {}
769+ for i , (a , x ) in enumerate (zip (avals , explicit_args )):
770+ if (isinstance (a , AbstractRef ) and
771+ (dup_idx := refs .setdefault (id (core .get_referent (x )), i )) != i ):
772+ raise ValueError (
773+ "only one reference to a mutable array may be passed as an argument "
774+ f"to a function, but when tracing { dbg .func_src_info } for { dbg .traced_for } "
775+ f"the mutable array reference of type { a .str_short ()} appeared at both "
776+ f"{ dbg .arg_names [dup_idx ]} and { dbg .arg_names [i ]} ."
777+ if dbg else
778+ f"at both flat index { dup_idx } and flat index { i } " ) from None
779+ return tuple (avals )
780+
781+ def _check_no_aliased_closed_over_refs (dbg , consts , args ) -> None :
782+ if not config .mutable_array_checks .value : return
783+ refs : set [int ] = {id (core .get_referent (c )) for c in consts
784+ if isinstance (core .get_aval (c ), AbstractRef )}
785+ for i , x in enumerate (args ):
786+ if id (core .get_referent (x )) in refs :
787+ a = shaped_abstractify (x )
788+ raise ValueError (
789+ f"when tracing { dbg .func_src_info } for { dbg .traced_for } , a mutable "
790+ f"array reference of type { a .str_short ()} was both closed over and "
791+ f"passed as the argument "
792+ f"{ dbg .arg_names [i ]} " if dbg else "at flat index {i}" )
778793
779794def _extract_implicit_args (
780795 in_type : Sequence [tuple [core .AbstractValue , bool ]],
0 commit comments