@@ -126,6 +126,7 @@ function MAP_joint(
126126 ϕtol = nothing ,
127127 progress:: Bool = true ,
128128 verbosity = (0 ,0 ),
129+ linesearch = OptimKit. HagerZhangLineSearch (verbosity= verbosity[2 ], maxiter= 5 ),
129130 conjgrad_kwargs = (tol= 1e-1 ,nsteps= 500 ),
130131 quasi_sample = false ,
131132 preconditioner = :diag ,
@@ -163,8 +164,11 @@ function MAP_joint(
163164 push! (history, select ((;f,f°,ϕ,∇ϕ_lnP= nothing ,χ²= nothing ,lnP= nothing ,argmaxf_lnP_history), history_keys))
164165
165166 # objective function (with gradient) to maximize
166- @⌛ function objective (ϕ)
167- @⌛ (sum (unbatch (- 2 lnP (:mix ,f°,ϕ,dsθ)))), @⌛ (gradient (ϕ-> - 2 lnP (:mix ,f°,ϕ,dsθ), ϕ)[1 ])
167+ @⌛ function objective (ϕ; need_gradient= true )
168+ (
169+ @⌛ (sum (unbatch (- 2 lnP (:mix ,f°,ϕ,dsθ)))),
170+ need_gradient ? @⌛ (gradient (ϕ-> - 2 lnP (:mix ,f°,ϕ,dsθ), ϕ)[1 ]) : nothing
171+ )
168172 end
169173 # function to compute after each optimization iteration, which
170174 # recomputes the best-fit f given the current ϕ
@@ -182,7 +186,9 @@ function MAP_joint(
182186 ∇ϕ_lnP .= @⌛ gradient (ϕ-> - 2 lnP (:mix ,f°,ϕ,dsθ), ϕ)[1 ]
183187 χ²s = @⌛ - 2 lnP (:mix ,f°,ϕ,dsθ)
184188 χ² = sum (unbatch (χ²s))
185- next! (pbar, showvalues= [(" step" ,i), (" χ²" ,χ²s), (" Ncg" ,length (argmaxf_lnP_history))])
189+ values = [(" step" ,i), (" χ²" ,χ²s), (" Ncg" ,length (argmaxf_lnP_history))]
190+ hasproperty (linesearch, :α ) && push! (values, (" α" , linesearch. α))
191+ next! (pbar, showvalues= values)
186192 push! (history, select ((;f,f°,ϕ,∇ϕ_lnP,χ²,lnP= - χ²/ 2 ,argmaxf_lnP_history), history_keys))
187193 if (
188194 ! isnothing (ϕtol) &&
@@ -204,7 +210,7 @@ function MAP_joint(
204210 lbfgs_rank;
205211 maxiter = nsteps,
206212 verbosity = verbosity[1 ],
207- linesearch = OptimKit . HagerZhangLineSearch (verbosity = verbosity[ 2 ], maxiter = 5 )
213+ linesearch = linesearch
208214 );
209215 finalize!,
210216 inner = (_,ξ1,ξ2)-> sum (unbatch (dot (ξ1,ξ2))),