@@ -105,22 +105,56 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
105105 store .store (aux_primals )
106106 return out_primals , out_tangents
107107
108+ def convert_constvars_jaxpr_constvars_at_end (jaxpr : core .Jaxpr ) -> core .Jaxpr :
109+ dbg = jaxpr .debug_info and jaxpr .debug_info ._replace (
110+ arg_names = jaxpr .debug_info .arg_names + (None ,) * len (jaxpr .constvars ))
111+ return core .Jaxpr (constvars = (),
112+ invars = jaxpr .invars + jaxpr .constvars ,
113+ outvars = jaxpr .outvars , eqns = jaxpr .eqns ,
114+ effects = jaxpr .effects , debug_info = dbg )
115+
116+ def linearize_jaxpr (jaxpr , nonzeros ):
117+ primal_trace = pe .DynamicJaxprTrace ()
118+ tangent_trace = pe .DynamicJaxprTrace ()
119+ lin_trace = LinearizeTrace (primal_trace , tangent_trace )
120+
121+ def new_arg (primal_aval , nz ):
122+ primal = primal_trace .new_arg (primal_aval )
123+ tangent_aval = primal_aval .to_tangent_aval ()
124+ tangent = tangent_trace .new_arg (tangent_aval ) if nz else Zero (tangent_aval )
125+ return LinearizeTracer (lin_trace , primal , tangent )
126+
127+ tracers = [new_arg (v .aval , nz ) for (v , nz ) in zip (jaxpr .jaxpr .invars , nonzeros )]
128+ with core .set_current_trace (lin_trace ):
129+ ans = core .eval_jaxpr (jaxpr .jaxpr , jaxpr .consts , * tracers )
130+
131+ out_primals , out_tangents = unzip2 (map (lin_trace .to_primal_tangent_pair , ans ))
132+ nzs_out = [type (t ) is not Zero for t in out_tangents ]
133+ out_tangents = [tangent_trace .to_jaxpr_tracer (t )
134+ for (nz , t ) in zip (nzs_out , out_tangents ) if nz ]
135+ tangent_jaxpr , tangent_consts , attrs_tracked = tangent_trace .to_jaxpr (out_tangents )
136+ del attrs_tracked # TODO: attrs
137+ residuals_and_primals = (* tangent_consts , * out_primals )
138+ primal_jaxpr , primal_consts , attrs_tracked = primal_trace .to_jaxpr (residuals_and_primals )
139+ num_residuals = len (tangent_consts )
140+ tangent_jaxpr = pe .close_jaxpr (convert_constvars_jaxpr_constvars_at_end (tangent_jaxpr ))
141+ del attrs_tracked # TODO: attrs
142+ return core .ClosedJaxpr (primal_jaxpr , primal_consts ), num_residuals , nzs_out , tangent_jaxpr
143+
108144def direct_linearize (traceable , * primals , ** kwargs ):
109145 has_aux = kwargs .pop ('has_aux' , False )
110146 assert not has_aux
111147 with core .take_current_trace () as parent_trace :
112- frame = pe .JaxprStackFrame ()
113- tangent_trace = pe .DynamicJaxprTrace (frame )
148+ tangent_trace = pe .DynamicJaxprTrace ()
114149 tangents = [tangent_trace .new_arg (get_aval (p ).to_tangent_aval ()) for p in primals ]
115- tag = core .TraceTag ()
116- linearize_trace = LinearizeTrace (parent_trace , tangent_trace , tag )
150+ linearize_trace = LinearizeTrace (parent_trace , tangent_trace )
117151 tracers = [LinearizeTracer (linearize_trace , p , t ) for p , t in zip (primals , tangents )]
118152 with core .set_current_trace (linearize_trace ):
119153 ans = traceable .call_wrapped (* tracers )
120154
121155 out_primals , out_tangents = unzip2 (map (linearize_trace .to_primal_tangent_pair , ans ))
122156 out_tangents = map (tangent_trace .to_jaxpr_tracer , out_tangents )
123- jaxpr , consts , attrs_tracked = frame .to_jaxpr (tangent_trace , out_tangents )
157+ jaxpr , consts , attrs_tracked = tangent_trace .to_jaxpr (out_tangents )
124158 out_tangents_pvals = [pe .PartialVal .unknown (core .get_aval (t )) for t in out_tangents ]
125159 del attrs_tracked # TODO: attrs
126160 return out_primals , out_tangents_pvals , jaxpr , consts
@@ -469,8 +503,8 @@ def _primal_tangent_shapes_match(primal, tangent):
469503
470504class LinearizeTrace (Trace ):
471505
472- def __init__ (self , parent_trace , tangent_trace , tag ):
473- self .tag = tag
506+ def __init__ (self , parent_trace , tangent_trace , tag = None ):
507+ self .tag = core . TraceTag () if tag is None else tag
474508 self .parent_trace = parent_trace
475509 self .tangent_trace = tangent_trace
476510
@@ -509,18 +543,20 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
509543 return primal
510544
511545def fallback_linearize_rule (prim , _ , * args , ** kwargs ):
546+ assert not prim .multiple_results
547+
512548 def call_prim (* args_ ):
513- return prim .bind (* args_ , ** kwargs )
549+ return [prim .bind (* args_ , ** kwargs )]
550+
514551 with config .use_direct_linearize (False ):
515- out_primals , out_tangents_pvals , jaxpr , consts , * _maybe_aux = linearize (
552+ ( out_primal ,), ( out_tangent_pval ,) , jaxpr , consts , * _maybe_aux = linearize (
516553 lu .wrap_init (call_prim ), * args , ** kwargs )
554+
517555 def linearized (residuals , * tangents ):
518- tangents_out = iter (core .eval_jaxpr (jaxpr , residuals , * tangents ))
519- full_out = [pval .get_known () if pval .is_known () else next (tangents_out )
520- for pval in out_tangents_pvals ]
521- assert next (tangents_out , None ) is None
522- return full_out
523- return out_primals , [True for _ in out_primals ], consts , linearized
556+ out_tangent , = core .eval_jaxpr (jaxpr , residuals , * tangents )
557+ return out_tangent
558+
559+ return out_primal , True , consts , linearized
524560
525561class LinearizeTracer (Tracer ):
526562 __slots__ = ['primal' , 'tangent' ]
0 commit comments