Skip to content

Commit d6c4dbb

Browse files
committed
make linesearch customaizable
1 parent aac575e commit d6c4dbb

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

src/maximization.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ function MAP_joint(
126126
ϕtol = nothing,
127127
progress::Bool = true,
128128
verbosity = (0,0),
129+
linesearch = OptimKit.HagerZhangLineSearch(verbosity=verbosity[2], maxiter=5),
129130
conjgrad_kwargs = (tol=1e-1,nsteps=500),
130131
quasi_sample = false,
131132
preconditioner = :diag,
@@ -163,8 +164,11 @@ function MAP_joint(
163164
push!(history, select((;f,f°,ϕ,∇ϕ_lnP=nothing,χ²=nothing,lnP=nothing,argmaxf_lnP_history), history_keys))
164165

165166
# objective function (with gradient) to maximize
166-
@⌛ function objective(ϕ)
167-
@⌛(sum(unbatch(-2lnP(:mix,f°,ϕ,dsθ)))), @⌛(gradient->-2lnP(:mix,f°,ϕ,dsθ), ϕ)[1])
167+
@⌛ function objective(ϕ; need_gradient=true)
168+
(
169+
@⌛(sum(unbatch(-2lnP(:mix,f°,ϕ,dsθ)))),
170+
need_gradient ? @⌛(gradient->-2lnP(:mix,f°,ϕ,dsθ), ϕ)[1]) : nothing
171+
)
168172
end
169173
# function to compute after each optimization iteration, which
170174
# recomputes the best-fit f given the current ϕ
@@ -182,7 +186,9 @@ function MAP_joint(
182186
∇ϕ_lnP .= @⌛ gradient->-2lnP(:mix,f°,ϕ,dsθ), ϕ)[1]
183187
χ²s = @⌛ -2lnP(:mix,f°,ϕ,dsθ)
184188
χ² = sum(unbatch(χ²s))
185-
next!(pbar, showvalues=[("step",i), ("χ²",χ²s), ("Ncg",length(argmaxf_lnP_history))])
189+
values = [("step",i), ("χ²",χ²s), ("Ncg",length(argmaxf_lnP_history))]
190+
hasproperty(linesearch, ) && push!(values, ("α", linesearch.α))
191+
next!(pbar, showvalues=values)
186192
push!(history, select((;f,f°,ϕ,∇ϕ_lnP,χ²,lnP=-χ²/2,argmaxf_lnP_history), history_keys))
187193
if (
188194
!isnothing(ϕtol) &&
@@ -204,7 +210,7 @@ function MAP_joint(
204210
lbfgs_rank;
205211
maxiter = nsteps,
206212
verbosity = verbosity[1],
207-
linesearch = OptimKit.HagerZhangLineSearch(verbosity=verbosity[2], maxiter=5)
213+
linesearch = linesearch
208214
);
209215
finalize!,
210216
inner = (_,ξ1,ξ2)->sum(unbatch(dot(ξ1,ξ2))),

src/plotting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ function _plot(f, ax, k, title, vlim, vscale, cmap; cbar=true, units=:deg, tickl
137137

138138
# annonate
139139
if cbar
140-
colorbar(cax,ax=ax)
140+
colorbar(cax,ax=ax,pad=0.01)
141141
end
142142
ax.set_title(title, y=1)
143143
if ticklabels

0 commit comments

Comments
 (0)