@@ -2076,14 +2076,22 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
20762076 donated_invars , ctx_mesh , name , keep_unused , inline ,
20772077 compiler_options_kvs ):
20782078 primal_jaxpr , num_residuals , nzs_out , tangent_jaxpr = ad .linearize_jaxpr (jaxpr , nzs )
2079- # constvars will become residuals. Move them to the end of the ordinary args.
20802079 res_shardings = (UNSPECIFIED ,) * num_residuals
20812080 res_layouts = (None ,) * num_residuals
20822081 res_donated = (False ,) * num_residuals
2082+
2083+ in_fwd = pe ._jaxpr_forwarding (primal_jaxpr .jaxpr )
2084+ in_fwd , _ = split_list (in_fwd , [num_residuals ])
2085+ keep = tuple (f is None for f in in_fwd ) + (True ,) * len (out_shardings )
2086+ primal_jaxpr = pe .prune_closed_jaxpr_outputs (primal_jaxpr , keep )
2087+ num_residuals = sum (f is None for f in in_fwd )
2088+
20832089 def tangent_fun (consts_ , * tangents ):
2090+ consts_it = iter (consts_ )
2091+ res = [next (consts_it ) if f is None else primals_in [f ] for f in in_fwd ]
2092+ assert next (consts_it , None ) is None
20842093 tangents_nz = _filter_zeros (nzs , tangents )
2085- assert len (consts_ ) == num_residuals
2086- nz_tangents_out = pjit_p .bind (* (* tangents_nz , * consts_ ),
2094+ nz_tangents_out = pjit_p .bind (* (* tangents_nz , * res ),
20872095 jaxpr = tangent_jaxpr ,
20882096 in_shardings = _filter_zeros (nzs , in_shardings ) + res_shardings ,
20892097 out_shardings = _filter_zeros (nzs_out , out_shardings ),
@@ -2106,9 +2114,9 @@ def _filter_zeros(is_nz_l, l):
21062114
21072115 ans = pjit_p .bind (* primals_in , jaxpr = primal_jaxpr ,
21082116 in_shardings = in_shardings ,
2109- out_shardings = (* res_shardings , * out_shardings ),
2117+ out_shardings = (* res_shardings [: num_residuals ] , * out_shardings ),
21102118 in_layouts = in_layouts ,
2111- out_layouts = (* res_layouts , * out_layouts ),
2119+ out_layouts = (* res_layouts [: num_residuals ] , * out_layouts ),
21122120 donated_invars = donated_invars ,
21132121 ctx_mesh = ctx_mesh ,
21142122 name = name ,
0 commit comments