3030from jax ._src import source_info_util
3131from jax ._src .ad_util import (
3232 add_jaxvals , replace_internal_symbolic_zeros ,
33- replace_rule_output_symbolic_zeros , Zero , zeros_like_aval )
33+ replace_rule_output_symbolic_zeros , Zero , zeros_like_aval , SymbolicZero )
3434from jax ._src .ad_util import zeros_like_p , add_jaxvals_p # noqa: F401
3535from jax ._src .api_util import flatten_fun , flatten_fun_nokwargs
3636from jax ._src .core import (Trace , Tracer , get_aval , call_p , Primitive , Literal )
@@ -157,6 +157,7 @@ def new_arg(primal_aval, nz):
157157 if attrs_tracked :
158158 raise NotImplementedError ("TODO: attrs" )
159159 residuals_and_primals = (* tangent_consts , * out_primals )
160+ residuals_and_primals = map (primal_trace .to_jaxpr_tracer , residuals_and_primals )
160161 primal_jaxpr , primal_consts , attrs_tracked = primal_trace .to_jaxpr (residuals_and_primals )
161162 num_residuals = len (tangent_consts )
162163 tangent_jaxpr = pe .close_jaxpr (convert_constvars_jaxpr_constvars_at_end (tangent_jaxpr ))
@@ -168,8 +169,10 @@ def direct_linearize(traceable, primals, kwargs, *, has_aux=False, tag=None):
168169 with core .take_current_trace () as parent_trace :
169170 tangent_trace = pe .DynamicJaxprTrace ()
170171 tangents = [tangent_trace .new_arg (get_aval (p ).to_tangent_aval ()) for p in primals ]
172+ tangents = [Zero .from_primal_value (t ) if dtype (t ) == float0 else t for t in tangents ]
171173 linearize_trace = LinearizeTrace (parent_trace , tangent_trace , tag = tag )
172174 tracers = [LinearizeTracer (linearize_trace , p , t ) for p , t in zip (primals , tangents )]
175+ tracers = [t .full_lower () for t in tracers ]
173176 with core .set_current_trace (linearize_trace ):
174177 if has_aux :
175178 ans , aux = traceable .call_wrapped (* tracers )
@@ -586,20 +589,18 @@ def process_custom_jvp_call(self, prim, fun, f_jvp, tracers, *, symbolic_zeros):
586589 if all (type (t ) is Zero for t in tangents_in ):
587590 return prim .bind_with_trace (self .parent_trace , (fun , f_jvp , * primals_in ),
588591 dict (symbolic_zeros = symbolic_zeros ))
589- with core .set_current_trace (self .parent_trace ):
590- if not symbolic_zeros :
591- tangents_in = map (instantiate_zeros , tangents_in )
592- else :
593- tangents_in = map (replace_internal_symbolic_zeros , tangents_in )
594- nonzeros_in = [type (t ) is not Zero for t in tangents_in ]
595592
596593 def _f_jvp (primals , tangents ):
597594 outs = f_jvp .call_wrapped (* primals , * tangents )
598595 primals_out , tangents_out = split_list (outs , [len (outs ) // 2 ])
599596 return primals_out , tangents_out
600597
601- primals_out , tangent_nzs_out , residuals , linearized = linearize_from_jvp (
602- _f_jvp , True , nonzeros_in , primals_in , {})
598+ with core .set_current_trace (self .parent_trace ):
599+ instantiate_zeros = not symbolic_zeros
600+ nonzeros_in = [type (t ) is not Zero for t in tangents_in ]
601+ primals_out , tangent_nzs_out , residuals , linearized = linearize_from_jvp (
602+ _f_jvp , True , nonzeros_in , symbolic_zeros , instantiate_zeros , primals_in , {})
603+
603604 with core .set_current_trace (self .tangent_trace ):
604605 tangents_out = linearized (residuals , * tangents_in )
605606 tangents_out = map (replace_rule_output_symbolic_zeros , tangents_out )
@@ -622,8 +623,8 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
622623 res , primals_out = split_list (res_and_primals_out , [res_tree .num_leaves ])
623624 avals_out = [core .get_aval (x ).to_tangent_aval () for x in primals_out ]
624625
626+ tangents_in = map (instantiate_zeros , tangents_in )
625627 with core .set_current_trace (self .tangent_trace ):
626- tangents_in = map (instantiate_zeros , tangents_in )
627628 tangents_out = custom_lin_p .bind (
628629 * res , * tangents_in , num_res = res_tree .num_leaves , bwd = bwd ,
629630 out_avals = avals_out , symbolic_zeros = symbolic_zeros )
@@ -666,14 +667,29 @@ def fallback_linearize_rule(_prim, _nonzeros, *primals, **params):
666667 if not jvp :
667668 msg = f"Differentiation rule for '{ _prim } ' not implemented"
668669 raise NotImplementedError (msg )
669- return linearize_from_jvp (jvp , _prim .multiple_results , _nonzeros , primals , params )
670+ return linearize_from_jvp (jvp , _prim .multiple_results , _nonzeros , False , False , primals , params )
670671
671- def linearize_from_jvp (jvp , multiple_results , nonzeros , primals , params ):
672+ def linearize_from_jvp (jvp , multiple_results , nonzeros ,
673+ user_facing_symbolic_zeros , instantiate_input_zeros , primals , params ):
672674 current_name_stack = source_info_util .current_name_stack ()
673675 with core .take_current_trace () as parent_trace :
674676 trace = pe .JaxprTrace (parent_trace , current_name_stack , core .TraceTag ())
675677 tangent_avals = [get_aval (p ).to_tangent_aval () for p in primals ]
676- tangent_args = [trace .new_arg (pe .PartialVal .unknown (aval )) if nz else Zero (aval )
678+
679+ def make_zero (aval ):
680+ if instantiate_input_zeros :
681+ return zeros_like_aval (aval )
682+ elif user_facing_symbolic_zeros :
683+ return SymbolicZero (aval )
684+ else :
685+ return Zero (aval )
686+
687+ if user_facing_symbolic_zeros :
688+ zero_type = SymbolicZero
689+ else :
690+ zero_type = Zero
691+
692+ tangent_args = [trace .new_arg (pe .PartialVal .unknown (aval )) if nz else make_zero (aval )
677693 for aval , nz in zip (tangent_avals , nonzeros )]
678694 with core .set_current_trace (trace ):
679695 out_primals , out_tangents = jvp (primals , tangent_args , ** params )
@@ -683,10 +699,11 @@ def linearize_from_jvp(jvp, multiple_results, nonzeros, primals, params):
683699 out_tangents = [out_tangents ]
684700
685701 out_primals = [trace .to_jaxpr_tracer (p ).pval .get_known () for p in out_primals ]
686- out_nzs = [type (r ) is not Zero for r in out_tangents ]
702+ out_nzs = [type (r ) is not zero_type for r in out_tangents ]
687703 out_tangent_avals = [get_aval (p ).to_tangent_aval () for p in out_primals ]
688- out_nz_tracers = [trace .to_jaxpr_tracer (r ) for (r , nz ) in zip (out_tangents , out_nzs ) if nz ]
689- in_tracers = [t for t in tangent_args if type (t ) is not Zero ]
704+ out_nz_tracers = [trace .instantiate_const (trace .to_jaxpr_tracer (r ))
705+ for (r , nz ) in zip (out_tangents , out_nzs ) if nz ]
706+ in_tracers = [t for t , nz in zip (tangent_args , nonzeros ) if nz ]
690707 jaxpr , out_consts , _ = pe .tracers_to_jaxpr (in_tracers , out_nz_tracers )
691708
692709 def linearized (residuals , * tangents ):
0 commit comments