@@ -143,21 +143,31 @@ def new_arg(primal_aval, nz):
143143
144144def direct_linearize (traceable , * primals , ** kwargs ):
145145 has_aux = kwargs .pop ('has_aux' , False )
146- assert not has_aux
147146 with core .take_current_trace () as parent_trace :
148147 tangent_trace = pe .DynamicJaxprTrace ()
149148 tangents = [tangent_trace .new_arg (get_aval (p ).to_tangent_aval ()) for p in primals ]
150149 linearize_trace = LinearizeTrace (parent_trace , tangent_trace )
151150 tracers = [LinearizeTracer (linearize_trace , p , t ) for p , t in zip (primals , tangents )]
152151 with core .set_current_trace (linearize_trace ):
153- ans = traceable .call_wrapped (* tracers )
154-
152+ if has_aux :
153+ ans , aux = traceable .call_wrapped (* tracers )
154+ aux_primals = [x .primal
155+ if isinstance (x , LinearizeTracer )
156+ and x ._trace .tag is linearize_trace .tag
157+ else x for x in aux ]
158+ else :
159+ ans = traceable .call_wrapped (* tracers )
160+ aux = None
155161 out_primals , out_tangents = unzip2 (map (linearize_trace .to_primal_tangent_pair , ans ))
162+ out_tangents = map (instantiate_zeros , out_tangents )
156163 out_tangents = map (tangent_trace .to_jaxpr_tracer , out_tangents )
157164 jaxpr , consts , attrs_tracked = tangent_trace .to_jaxpr (out_tangents )
158165 out_tangents_pvals = [pe .PartialVal .unknown (core .get_aval (t )) for t in out_tangents ]
159166 del attrs_tracked # TODO: attrs
160- return out_primals , out_tangents_pvals , jaxpr , consts
167+ if has_aux :
168+ return out_primals , out_tangents_pvals , jaxpr , consts , aux_primals
169+ else :
170+ return out_primals , out_tangents_pvals , jaxpr , consts
161171
162172def linearize (traceable , * primals , ** kwargs ):
163173 if config .use_direct_linearize .value :
@@ -532,22 +542,45 @@ def to_primal_tangent_pair(self, val):
532542
533543 def process_primitive (self , primitive , args , params ):
534544 primals_in , tangents_in = unzip2 (map (self .to_primal_tangent_pair , args ))
535- tangent_nonzeros = [type (t ) is not Zero for t in tangents_in ]
545+ tangent_nzs = [type (t ) is not Zero for t in tangents_in ]
536546 if all (type (t ) is Zero for t in tangents_in ):
537547 return primitive .bind_with_trace (self .parent_trace , primals_in , params )
538- lin = primitive_linearizations .get (primitive )
539- if lin is None :
540- lin = partial (fallback_linearize_rule , primitive )
548+ fallback = partial (fallback_linearize_rule , primitive )
549+ lin = primitive_linearizations .get (primitive , fallback )
541550 with core .set_current_trace (self .parent_trace ):
542- primal_out , tangent_nonzeros_out , residuals , linearized = lin (
543- tangent_nonzeros , * primals_in , ** params )
551+ primal_out , tangent_nzs_out , residuals , linearized = lin (
552+ tangent_nzs , * primals_in , ** params )
544553 with core .set_current_trace (self .tangent_trace ):
545554 tangent_out = linearized (residuals , * tangents_in )
546555 if primitive .multiple_results :
547556 return [maybe_linearize_tracer (self , x , nz , t )
548- for x , nz , t in zip (primal_out , tangent_nonzeros , tangent_out )]
557+ for x , nz , t in zip (primal_out , tangent_nzs_out , tangent_out )]
549558 else :
550- return maybe_linearize_tracer (self , primal_out , tangent_nonzeros , tangent_out )
559+ return maybe_linearize_tracer (self , primal_out , tangent_nzs_out , tangent_out )
560+
561+ def process_custom_vjp_call (self , prim , fun , fwd , bwd , tracers , out_trees ,
562+ symbolic_zeros ):
563+ primals_in , tangents_in = unzip2 (map (self .to_primal_tangent_pair , tracers ))
564+ if all (type (t ) is Zero for t in tangents_in ):
565+ return prim .bind_with_trace (self .parent_trace ,
566+ (fun , fwd , bwd , * primals_in ),
567+ dict (out_trees = out_trees , symbolic_zeros = symbolic_zeros ))
568+ fwd_in = [(p , type (t ) is not Zero ) for p , t in zip (primals_in , tangents_in )]
569+ fwd_in = [x for pair in fwd_in for x in pair ] # flatten
570+ with core .set_current_trace (self .parent_trace ):
571+ res_and_primals_out = fwd .call_wrapped (* fwd_in )
572+
573+ _ , res_tree = out_trees ()
574+ res , primals_out = split_list (res_and_primals_out , [res_tree .num_leaves ])
575+ avals_out = [core .get_aval (x ).to_tangent_aval () for x in primals_out ]
576+
577+ with core .set_current_trace (self .tangent_trace ):
578+ tangents_in = map (instantiate_zeros , tangents_in )
579+ tangents_out = custom_lin_p .bind (
580+ * res , * tangents_in , num_res = res_tree .num_leaves , bwd = bwd ,
581+ out_avals = avals_out , symbolic_zeros = symbolic_zeros )
582+ tangent_nzs_out = [type (t ) is not Zero for t in tangents_out ]
583+ return map (partial (maybe_linearize_tracer , self ), primals_out , tangent_nzs_out , tangents_out )
551584
552585def maybe_linearize_tracer (trace , primal , is_nonzero , tangent ):
553586 if is_nonzero :
@@ -557,21 +590,50 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
557590 assert type (tangent ) is Zero
558591 return primal
559592
560- def fallback_linearize_rule (prim , _ , * args , ** kwargs ):
561- assert not prim .multiple_results
562-
563- def call_prim (* args_ ):
564- return [prim .bind (* args_ , ** kwargs )]
565-
566- with config .use_direct_linearize (False ):
567- (out_primal ,), (out_tangent_pval ,), jaxpr , consts , * _maybe_aux = linearize (
568- lu .wrap_init (call_prim ), * args , ** kwargs )
593+ def fallback_linearize_rule (prim , nonzeros , * primals , ** params ):
594+ jvp = primitive_jvps .get (prim )
595+ if not jvp :
596+ msg = f"Differentiation rule for '{ prim } ' not implemented"
597+ raise NotImplementedError (msg )
598+ current_name_stack = source_info_util .current_name_stack ()
599+ with core .take_current_trace () as parent_trace :
600+ trace = pe .JaxprTrace (parent_trace , current_name_stack , core .TraceTag ())
601+ tangent_avals = [get_aval (p ).to_tangent_aval () for p in primals ]
602+ tangent_args = [trace .new_arg (pe .PartialVal .unknown (aval )) if nz else Zero (aval )
603+ for aval , nz in zip (tangent_avals , nonzeros )]
604+ with core .set_current_trace (trace ):
605+ out_primals , out_tangents = jvp (primals , tangent_args , ** params )
606+
607+ if not prim .multiple_results :
608+ out_primals = [out_primals ]
609+ out_tangents = [out_tangents ]
610+
611+ out_primals = [trace .to_jaxpr_tracer (p ).pval .get_known () for p in out_primals ]
612+ out_nzs = [type (r ) is not Zero for r in out_tangents ]
613+ out_tangent_avals = [get_aval (p ).to_tangent_aval () for p in out_primals ]
614+ out_nz_tracers = [trace .to_jaxpr_tracer (r ) for (r , nz ) in zip (out_tangents , out_nzs ) if nz ]
615+ in_tracers = [t for t in tangent_args if type (t ) is not Zero ]
616+ jaxpr , out_consts , _ = pe .tracers_to_jaxpr (in_tracers , out_nz_tracers )
617+
618+ def linearized (residuals , * tangents ):
619+ nz_tangents_in = [t for (t , nz ) in zip (tangents , nonzeros ) if nz ]
620+ nz_tangents_out = core .eval_jaxpr (jaxpr , residuals , * nz_tangents_in )
621+ nz_tangents_out_iter = iter (nz_tangents_out )
622+ all_out_tangents = [next (nz_tangents_out_iter ) if nz else Zero (aval )
623+ for (aval , nz ) in zip (out_tangent_avals , out_nzs )]
624+ if prim .multiple_results :
625+ return all_out_tangents
626+ else :
627+ out_tangent , = all_out_tangents
628+ return out_tangent
569629
570- def linearized (residuals , * tangents ):
571- out_tangent , = core .eval_jaxpr (jaxpr , residuals , * tangents )
572- return out_tangent
630+ if prim .multiple_results :
631+ return out_primals , out_nzs , out_consts , linearized
632+ else :
633+ out_primal , = out_primals
634+ out_nz , = out_nzs
635+ return out_primal , out_nz , out_consts , linearized
573636
574- return out_primal , True , consts , linearized
575637
576638class LinearizeTracer (Tracer ):
577639 __slots__ = ['primal' , 'tangent' ]
0 commit comments