@@ -833,7 +833,7 @@ def new_body_f(*c_consts_and_vals):
833833 # This checks if the next cond application will error
834834 _ = cond_f (* c_consts , * out )
835835 return out
836- new_body_f_ = lu .wrap_init (new_body_f )
836+ new_body_f_ = lu .wrap_init (new_body_f , debug_info = body_jaxpr . jaxpr . debug_info )
837837 c_consts_avals = cond_jaxpr .in_avals [:c_consts_num ]
838838 jaxpr , _ , (), () = pe .trace_to_jaxpr_dynamic (new_body_f_ , [* c_consts_avals ,
839839 * body_jaxpr .in_avals ])
@@ -952,7 +952,8 @@ def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
952952
953953
954954def shard_map_error_check (
955- error , enabled_errors , * vals_in , jaxpr , in_names , out_names , ** kwargs
955+ error : Error , enabled_errors , * vals_in ,
956+ jaxpr : core .Jaxpr , in_names , out_names , ** kwargs
956957):
957958 if (mesh := kwargs .get ('mesh' )) is None :
958959 raise ValueError ('Mesh must be provided for shard_map with checkify.' )
@@ -976,7 +977,6 @@ def shard_map_error_check(
976977 )
977978 num_out_error_vals = out_tree .num_leaves - len (out_names )
978979
979- @lu .wrap_init
980980 def expand_errors_leading_dim (* xs ):
981981 outs = core .eval_jaxpr (checked_jaxpr .jaxpr , checked_jaxpr .consts , * xs )
982982 errs , outs = split_list (outs , [num_out_error_vals ])
@@ -985,15 +985,18 @@ def expand_errors_leading_dim(*xs):
985985
986986 with core .extend_axis_env_nd (mesh .shape .items ()):
987987 jaxpr , _ , consts , () = pe .trace_to_jaxpr_dynamic (
988- expand_errors_leading_dim , checked_jaxpr .in_avals
988+ lu .wrap_init (expand_errors_leading_dim ,
989+ debug_info = checked_jaxpr .jaxpr .debug_info ),
990+ checked_jaxpr .in_avals
989991 )
990992 checked_jaxpr = core .ClosedJaxpr (jaxpr , consts )
991993
992994 # Update shard_map params to account for extra error values.
993995 # Use fully sharded partitioning for out errors.
994996 new_out_names = (* ([{0 : mesh .axis_names }] * num_out_error_vals ), * out_names )
995997 subfun = lu .hashable_partial (
996- lu .wrap_init (core .eval_jaxpr ), checked_jaxpr .jaxpr , checked_jaxpr .consts
998+ lu .wrap_init (core .eval_jaxpr , debug_info = checked_jaxpr .jaxpr .debug_info ),
999+ checked_jaxpr .jaxpr , checked_jaxpr .consts
9971000 )
9981001 new_params = dict (
9991002 jaxpr = checked_jaxpr .jaxpr ,
@@ -1007,8 +1010,10 @@ def expand_errors_leading_dim(*xs):
10071010 return tree_unflatten (out_tree , err_and_out )
10081011error_checks [shard_map .shard_map_p ] = shard_map_error_check
10091012
1010- def custom_jvp_call_rule (in_err , enabled_errors , * in_vals , num_consts ,
1011- jvp_jaxpr_thunk , call_jaxpr , ** params ):
1013+ def custom_jvp_call_rule (in_err : Error ,
1014+ enabled_errors : set , * in_vals , num_consts ,
1015+ jvp_jaxpr_fun : lu .WrappedFun ,
1016+ call_jaxpr : core .ClosedJaxpr , ** params ):
10121017 # The types to have in mind are:
10131018 # jvp : (a -> b) -> (a, T a) -> (b, T b)
10141019 # checkify : (a -> b) -> a -> Err b
@@ -1021,10 +1026,11 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
10211026 err_vals , err_tree = jtu .tree_flatten (in_err )
10221027 partial_checkify = lu .wrap_init (
10231028 functools .partial (checkify_jaxpr_flat , call_jaxpr .jaxpr ,
1024- call_jaxpr .consts , enabled_errors , err_tree ))
1029+ call_jaxpr .consts , enabled_errors , err_tree ),
1030+ debug_info = call_jaxpr .jaxpr .debug_info )
10251031 partial_checkify , f_metadata = _flatten_and_get_error_metadata_thunk (
10261032 partial_checkify )
1027- jvp = lift_jvp (err_tree .num_leaves , num_consts , jvp_jaxpr_thunk )
1033+ jvp = lift_jvp (err_tree .num_leaves , num_consts , jvp_jaxpr_fun )
10281034 jvp , jvp_out_tree = flatten_fun_output (jvp )
10291035 all_outs = custom_derivatives .custom_jvp_call_p .bind (
10301036 partial_checkify , jvp , * err_vals , * in_vals , ** params )
@@ -1041,17 +1047,17 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
10411047
10421048# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
10431049# outputs that checkify adds (just forwarding the error data's primal and
1044- # tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
1050+ # tangent components). The jaxpr in jvp_jaxpr_fun doesn't expect those.
10451051# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
10461052# Adding another layer of lu.transformation was tricky, though maybe doable.
1047- def lift_jvp (num_errs , num_consts , jvp_jaxpr_thunk ):
1048- @ lu .wrap_init
1053+ def lift_jvp (num_errs : int , num_consts : int ,
1054+ jvp_jaxpr_fun : lu .WrappedFun ) -> lu . WrappedFun :
10491055 def jvp (* xs ):
10501056 n , ragged = divmod (len (xs ), 2 )
10511057 assert not ragged
10521058 primals , tangents = xs [num_consts + num_errs :n ], xs [n + num_consts + num_errs :]
10531059 zeros = [type (t ) is SymbolicZero for t in tangents ]
1054- jvp_jaxpr , jvp_consts , out_zeros = jvp_jaxpr_thunk (* zeros )
1060+ jvp_jaxpr , jvp_consts , out_zeros = jvp_jaxpr_fun . call_wrapped (* zeros )
10551061 nonzero_tangents = [t for t in tangents if type (t ) is not SymbolicZero ]
10561062 out = core .eval_jaxpr (jvp_jaxpr , jvp_consts , * primals , * nonzero_tangents )
10571063 out_primals , nz_out_tangents = split_list (out , [len (out_zeros )])
@@ -1063,7 +1069,7 @@ def jvp(*xs):
10631069 primal_errs = xs [num_consts :num_consts + num_errs ]
10641070 tangent_errs = xs [n + num_consts :n + num_consts + num_errs ]
10651071 return [* primal_errs , * out_primals , * tangent_errs , * out_tangents ]
1066- return jvp
1072+ return lu . wrap_init ( jvp , debug_info = jvp_jaxpr_fun . debug_info )
10671073
10681074def custom_vjp_call_jaxpr_rule (in_err , enabled_errors , * in_vals ,
10691075 fun_jaxpr : core .ClosedJaxpr ,
0 commit comments