@@ -483,39 +483,44 @@ def to_primal_tangent_pair(self, val):
483483
def process_primitive(self, primitive, args, params):
  """Apply ``primitive`` under this trace, linearizing it if any tangent is nonzero.

  Each argument is split into a (primal, tangent) pair. If every input tangent
  is a symbolic ``Zero`` the primitive is simply bound in the parent trace.
  Otherwise the primitive's linearization rule is looked up in
  ``primitive_linearizations`` (falling back to ``fallback_linearize_rule``)
  and called as ``lin(tangent_nonzeros, *primals_in, **params)``. The rule
  returns ``(primal_out, tangent_nonzeros_out, residuals, linearized)``, and
  ``linearized(residuals, *tangents_in)`` is then evaluated in the tangent
  trace to produce the output tangents.

  Args:
    primitive: the primitive being applied.
    args: the input values; each is converted with
      ``self.to_primal_tangent_pair``.
    params: the primitive's keyword parameters.

  Returns:
    The result of ``maybe_linearize_tracer`` for the output (a list of such
    results when ``primitive.multiple_results`` is set), or the plain bound
    result when all input tangents are zero.
  """
  primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
  tangent_nonzeros = [type(t) is not Zero for t in tangents_in]
  if not any(tangent_nonzeros):
    # No input carries a tangent, so there is nothing to linearize.
    return primitive.bind_with_trace(self.parent_trace, primals_in, params)
  lin = primitive_linearizations.get(primitive)
  if lin is None:
    lin = partial(fallback_linearize_rule, primitive)
  with core.set_current_trace(self.parent_trace):
    primal_out, tangent_nonzeros_out, residuals, linearized = lin(
        tangent_nonzeros, *primals_in, **params)
  with core.set_current_trace(self.tangent_trace):
    tangent_out = linearized(residuals, *tangents_in)
  # Outputs must be classified with the *output* nonzero flags reported by the
  # rule. The input flags describe the arguments: they may differ in number
  # from the outputs, and as a list they are not a valid single-output flag.
  if primitive.multiple_results:
    return [maybe_linearize_tracer(self, x, nz, t)
            for x, nz, t in zip(primal_out, tangent_nonzeros_out, tangent_out)]
  else:
    return maybe_linearize_tracer(
        self, primal_out, tangent_nonzeros_out, tangent_out)
499502
def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
  """Wrap ``primal`` in a ``LinearizeTracer`` only when its tangent is nonzero.

  Args:
    trace: the trace passed through to the ``LinearizeTracer`` constructor.
    primal: the primal output value.
    is_nonzero: truthy if the output is declared to carry a nonzero tangent.
    tangent: the output tangent; must be a ``Zero`` exactly when
      ``is_nonzero`` is falsy (checked with assertions).

  Returns:
    ``LinearizeTracer(trace, primal, tangent)`` when ``is_nonzero`` is truthy,
    otherwise ``primal`` unchanged.
  """
  tangent_is_zero = type(tangent) is Zero
  if not is_nonzero:
    assert tangent_is_zero
    return primal
  assert not tangent_is_zero
  return LinearizeTracer(trace, primal, tangent)
505510
def fallback_linearize_rule(prim, _, *args, **kwargs):
  """Linearization rule used when ``prim`` has no registered rule.

  Runs ``linearize`` (with ``config.use_direct_linearize`` set to ``False``)
  on a function that binds ``prim`` with ``kwargs``, and repackages the result
  in the four-element form ``process_primitive`` unpacks.

  Args:
    prim: the primitive to linearize.
    _: the nonzero flags of the input tangents; ignored by this rule.
    *args: the primal inputs.
    **kwargs: the primitive's parameters.

  Returns:
    A tuple ``(out_primals, nonzeros, consts, linearized)`` where ``nonzeros``
    marks every output as nonzero, ``consts`` are the residuals, and
    ``linearized(residuals, *tangents)`` evaluates the linearized jaxpr.
  """
  def call_prim(*args_):
    return prim.bind(*args_, **kwargs)

  with config.use_direct_linearize(False):
    linearize_out = linearize(lu.wrap_init(call_prim), *args, **kwargs)
    out_primals, out_tangents_pvals, jaxpr, consts, *_aux = linearize_out

  def linearized(residuals, *tangents):
    # The jaxpr only yields the tangents that were not already known; the
    # known ones are read straight from their partial values, in order.
    unknown_tangents = iter(core.eval_jaxpr(jaxpr, residuals, *tangents))
    full_out = []
    for pval in out_tangents_pvals:
      if pval.is_known():
        full_out.append(pval.get_known())
      else:
        full_out.append(next(unknown_tangents))
    # Every jaxpr output must have been consumed.
    assert next(unknown_tangents, None) is None
    return full_out

  all_nonzero = [True for _ in out_primals]
  return out_primals, all_nonzero, consts, linearized
519524
520525class LinearizeTracer (Tracer ):
521526 __slots__ = ['primal' , 'tangent' ]
@@ -547,7 +552,7 @@ def to_concrete_value(self):
# Registry keyed by primitive; presumably holds each primitive's transpose
# rule -- its use is not visible in this part of the file, confirm at use sites.
primitive_transposes: dict[core.Primitive, Callable] = {}
# transpose rules that internally perform reductions over the given named axes
reducing_transposes: dict[core.Primitive, Callable] = {}
# Per-primitive linearization rules. process_primitive looks a primitive up
# here and uses fallback_linearize_rule when no entry is registered.
primitive_linearizations: dict[core.Primitive, Callable] = {}
551556
552557def deflinear (primitive , transpose_rule ):
553558 primitive_jvps [primitive ] = partial (linear_jvp , primitive )
0 commit comments