
Commit 7c2d4d9

refactor optimization to make it easy to add new backends (#802)

* refactor fit!(::LinearMixedModel) to support a more flexible optimization backend
* fix crossref
* NEWS and version bump
* OptSummary show specialization
* update PRIMA backend to use new infrastructure
* test other PRIMA optimizers
* nlopt for GLMM
* prima for GLMM
* use scaling
* updated show methods
* make thinning a noop
* BlueStyle
* docs update
1 parent 1ea4083 commit 7c2d4d9

21 files changed: +573 -287 lines
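In user-facing terms, the refactor routes `prfit!` through the same `fit!` machinery as every other backend. A minimal sketch of the new flow (based on the extension diff below; it assumes the PRIMA extension is activated by loading PRIMA, and uses the `sleepstudy` dataset bundled with MixedModels):

```julia
using MixedModels, PRIMA  # loading PRIMA activates MixedModelsPRIMAExt

model = LinearMixedModel(@formula(reaction ~ 1 + days + (1 + days | subj)),
                         MixedModels.dataset(:sleepstudy))
# prfit! now just tags the model with the PRIMA backend and delegates to fit!
prfit!(model)
model.optsum.backend    # :prima
model.optsum.optimizer  # :bobyqa
```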

NEWS.md (7 additions, 0 deletions)

@@ -1,3 +1,9 @@
+MixedModels v4.30.0 Release Notes
+==============================
+- Refactor calls to backend optimizer to make it easier to add and use different optimization backends.
+  The structure of `OptSummary` has been accordingly expanded and `prfit!` has been updated to use this new structure. [#802]
+- Make the `thin` argument to `fit!` a no-op. It complicated several bits of logic without having any real performance benefit in the majority of cases. This argument has been replaced with a `fitlog::Bool=false` that determines whether a log is kept. [#802]
+
 MixedModels v4.29.1 Release Notes
 ==============================
 - Populate `optsum` in `prfit!` call. [#801]

@@ -595,3 +601,4 @@ Package dependencies
 [#795]: https://github.com/JuliaStats/MixedModels.jl/issues/795
 [#799]: https://github.com/JuliaStats/MixedModels.jl/issues/799
 [#801]: https://github.com/JuliaStats/MixedModels.jl/issues/801
+[#802]: https://github.com/JuliaStats/MixedModels.jl/issues/802
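Usage of the new `fitlog` keyword, per the release note above (a sketch using the bundled `sleepstudy` dataset; `fit` is assumed to forward the keyword to `fit!`, as it does for `progress`):

```julia
using MixedModels

fm = fit(MixedModel, @formula(reaction ~ 1 + days + (1 + days | subj)),
         MixedModels.dataset(:sleepstudy); fitlog=true)
fm.optsum.fitlog[1]       # (parameter vector, objective) at the first evaluation
length(fm.optsum.fitlog)  # one entry per evaluation; empty without fitlog=true
```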

Project.toml (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
 name = "MixedModels"
 uuid = "ff71e718-51f3-5ec2-a782-8ffcbfa3c316"
 author = ["Phillip Alday <[email protected]>", "Douglas Bates <[email protected]>"]
-version = "4.29.1"
+version = "4.30.0"
 
 [deps]
 Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"

docs/src/optimization.md (3 additions, 3 deletions)

@@ -190,11 +190,11 @@ DisplayAs.Text(ans) # hide
 ```
 
 More detailed information about the intermediate steps of the nonlinear optimizer can be obtained from the `fitlog` field.
-By default, `fitlog` contains entries for only the initial and final steps, but additional information about every nth step can be obtained with the `thin` keyword-argument to `fit`, `fit!` and `refit!`:
+By default, `fitlog` is not populated, but passing the keyword argument `fitlog=true` to `fit!` and `refit!` will result in it being populated with the values obtained at each step of optimization:
 
 ```@example Main
-refit!(fm2; thin=1)
-fm2.optsum.fitlog[1:10]
+refit!(fm2; fitlog=true)
+first(fm2.optsum.fitlog, 5)
 DisplayAs.Text(ans) # hide
 ```
ext/MixedModelsPRIMAExt.jl (100 additions, 21 deletions)

@@ -1,27 +1,41 @@
 module MixedModelsPRIMAExt
 
-using MixedModels: MixedModels, LinearMixedModel, objective!
-using MixedModels: ProgressMeter, ProgressUnknown
+using MixedModels
+using MixedModels: Statistics
+using MixedModels: ProgressMeter, ProgressUnknown, objective!, _objective!
+using LinearAlgebra: PosDefException
 using PRIMA: PRIMA
 
+function __init__()
+    push!(MixedModels.OPTIMIZATION_BACKENDS, :prima)
+    return nothing
+end
+
+const PRIMABackend = Val{:prima}
+
 function MixedModels.prfit!(m::LinearMixedModel;
-                            progress::Bool=true,
-                            REML::Bool=m.optsum.REML,
-                            σ::Union{Real,Nothing}=m.optsum.sigma,
-                            thin::Int=1)
-    optsum = m.optsum
-    copyto!(optsum.final, optsum.initial)
-    optsum.REML = REML
-    optsum.sigma = σ
-    optsum.finitial = objective!(m, optsum.initial)
+                            kwargs...)
+
+    MixedModels.unfit!(m)
+    m.optsum.optimizer = :bobyqa
+    m.optsum.backend = :prima
 
+    return fit!(m; kwargs...)
+end
+
+prima_optimizer!(::Val{:bobyqa}, args...; kwargs...) = PRIMA.bobyqa!(args...; kwargs...)
+prima_optimizer!(::Val{:cobyla}, args...; kwargs...) = PRIMA.cobyla!(args...; kwargs...)
+prima_optimizer!(::Val{:lincoa}, args...; kwargs...) = PRIMA.lincoa!(args...; kwargs...)
+
+function MixedModels.optimize!(m::LinearMixedModel, ::PRIMABackend;
+                               progress::Bool=true, fitlog::Bool=false, kwargs...)
+    optsum = m.optsum
     prog = ProgressUnknown(; desc="Minimizing", showspeed=true)
-    # start from zero for the initial call to obj before optimization
-    iter = 0
-    fitlog = empty!(optsum.fitlog)
+    fitlog && empty!(optsum.fitlog)
+
     function obj(x)
-        iter += 1
-        val = if isone(iter) && x == optsum.initial
+        val = if x == optsum.initial
+            # fast path since we've already evaluated the initial value
             optsum.finitial
         else
             try

@@ -37,19 +51,84 @@ function MixedModels.prfit!(m::LinearMixedModel;
             end
         end
         progress && ProgressMeter.next!(prog; showvalues=[(:objective, val)])
-        if isone(iter) || iszero(rem(iter, thin))
-            push!(fitlog, (copy(x), val))
+        fitlog && push!(optsum.fitlog, (copy(x), val))
+        return val
+    end
+
+    maxfun = optsum.maxfeval > 0 ? optsum.maxfeval : 500 * length(optsum.initial)
+    info = prima_optimizer!(Val(optsum.optimizer), obj, optsum.final;
+                            xl=optsum.lowerbd, maxfun,
+                            optsum.rhoend, optsum.rhobeg)
+    ProgressMeter.finish!(prog)
+    optsum.feval = info.nf
+    optsum.fmin = info.fx
+    optsum.returnvalue = Symbol(info.status)
+    _check_prima_return(info)
+    return optsum.final, optsum.fmin
+end
+
+function MixedModels.optimize!(m::GeneralizedLinearMixedModel, ::PRIMABackend;
+                               progress::Bool=true, fitlog::Bool=false,
+                               fast::Bool=false, verbose::Bool=false, nAGQ=1,
+                               kwargs...)
+    optsum = m.optsum
+    prog = ProgressUnknown(; desc="Minimizing", showspeed=true)
+    fitlog && empty!(optsum.fitlog)
+
+    function obj(x)
+        val = try
+            _objective!(m, x, Val(fast); verbose, nAGQ)
+        catch ex
+            # this allows us to recover from models where e.g. the link isn't
+            # as constraining as it should be
+            ex isa Union{PosDefException,DomainError} || rethrow()
+            x == optsum.initial && rethrow()
+            m.optsum.finitial
         end
+        fitlog && push!(optsum.fitlog, (copy(x), val))
+        verbose && println(round(val; digits=5), " ", x)
+        progress && ProgressMeter.next!(prog; showvalues=[(:objective, val)])
         return val
     end
 
+    optsum.finitial = _objective!(m, optsum.initial, Val(fast); verbose, nAGQ)
+    maxfun = optsum.maxfeval > 0 ? optsum.maxfeval : 500 * length(optsum.initial)
+    scale = if fast
+        nothing
+    else
+        # scale by the standard deviation of the columns of the fixef model matrix
+        # when including the fixef in the nonlinear opt
+        sc = [map(std, eachcol(modelmatrix(m))); fill(1, length(m.θ))]
+        for (i, x) in enumerate(sc)
+            # for nearly constant things, e.g. intercept, we don't want to scale to zero...
+            # also, since we're scaling the _parameters_ and not the data,
+            # we need to invert the scale
+            sc[i] = ifelse(iszero(x), one(x), inv(x))
+        end
+        sc
+    end
+    info = prima_optimizer!(Val(optsum.optimizer), obj, optsum.final;
+                            xl=optsum.lowerbd, maxfun,
+                            optsum.rhoend, optsum.rhobeg,
+                            scale)
     ProgressMeter.finish!(prog)
-    info = PRIMA.bobyqa!(obj, optsum.final; xl=m.optsum.lowerbd)
+
     optsum.feval = info.nf
     optsum.fmin = info.fx
     optsum.returnvalue = Symbol(info.status)
-    optsum.optimizer = :PRIMA_BOBYQA
-    return m
+    _check_prima_return(info)
+
+    return optsum.final, optsum.fmin
+end
+
+function _check_prima_return(info::PRIMA.Info)
+    if !PRIMA.issuccess(info)
+        @warn "PRIMA optimization failure: $(info.status)\n$(PRIMA.reason(info))"
+    end
+
+    return nothing
 end
 
+MixedModels.opt_params(::PRIMABackend) = (:rhobeg, :rhoend, :maxfeval)
+
 end # module
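Taken together, the extension shows the full contract a new backend must satisfy: push its symbol onto `MixedModels.OPTIMIZATION_BACKENDS` at load time, implement `MixedModels.optimize!` for a `Val` of that symbol, and advertise its tunable keywords via `MixedModels.opt_params`. A minimal sketch for a hypothetical `:mybackend` (the module name and `my_minimize` are placeholders standing in for a real optimizer, not part of any package):

```julia
module MyBackendExt

using MixedModels
using MixedModels: objective!

function __init__()
    # register the backend symbol so fit!(m; backend=:mybackend) can find it
    push!(MixedModels.OPTIMIZATION_BACKENDS, :mybackend)
    return nothing
end

const MyBackend = Val{:mybackend}

function MixedModels.optimize!(m::LinearMixedModel, ::MyBackend;
                               progress::Bool=true, fitlog::Bool=false, kwargs...)
    optsum = m.optsum
    fitlog && empty!(optsum.fitlog)
    obj(x) = objective!(m, x)  # evaluate the profiled objective at θ = x
    # my_minimize is a placeholder for the backend's actual entry point
    xmin, fmin, nevals = my_minimize(obj, copy(optsum.initial); lower=optsum.lowerbd)
    copyto!(optsum.final, xmin)
    optsum.fmin = fmin
    optsum.feval = nevals
    optsum.returnvalue = :SUCCESS
    return optsum.final, optsum.fmin
end

# keyword arguments in fit! that this backend understands
MixedModels.opt_params(::MyBackend) = (:maxfeval,)

end # module
```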

src/MixedModels.jl (1 addition, 0 deletions)

@@ -205,6 +205,7 @@ include("grouping.jl")
 include("mimeshow.jl")
 include("serialization.jl")
 include("profile/profile.jl")
+include("nlopt.jl")
 include("prima.jl")
 
 # COV_EXCL_START

src/bootstrap.jl (2 additions, 2 deletions)

@@ -234,7 +234,7 @@ function parametricbootstrap(
     end
     β = convert(Vector{T}, β)
     θ = convert(Vector{T}, θ)
-    # scratch -- note that this is the length of the unpivoted coef vector
+    # scratch -- note that this is the length of the unpivoted coef vector
     βsc = coef(morig)
     θsc = zeros(ftype, length(θ))
     p = length(βsc)

@@ -254,7 +254,7 @@ function parametricbootstrap(
     )
     samp = replicate(n; progress) do
         simulate!(rng, m; β, σ, θ)
-        refit!(m; progress=false)
+        refit!(m; progress=false, fitlog=false)
         (
             objective=ftype.(m.objective),
             σ=ismissing(m.σ) ? missing : ftype(m.σ),
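The explicit `fitlog=false` here matters because the bootstrap refits the same model hundreds or thousands of times; logging every objective evaluation across replicates would grow `optsum.fitlog` without being attributable to any single replicate. A sketch of the call site's effect (hypothetical usage, assuming the bundled `sleepstudy` data):

```julia
using MixedModels, Random

m = fit(MixedModel, @formula(reaction ~ 1 + days + (1 + days | subj)),
        MixedModels.dataset(:sleepstudy))
boot = parametricbootstrap(MersenneTwister(42), 200, m)
# each internal refit! runs with progress=false, fitlog=false,
# so no per-replicate optimization log accumulates
length(boot.fits)  # 200 replicates
```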

src/generalizedlinearmixedmodel.jl (41 additions, 40 deletions)

@@ -219,7 +219,8 @@ and it may not be shown at all for models that are optimized quickly.
 
 If `verbose` is `true`, then the intermediate results of both the nonlinear optimization and PIRLS are also displayed on standard output.
 
-At every `thin`th iteration is recorded in `fitlog`, optimization progress is saved in `m.optsum.fitlog`.
+The `thin` argument is ignored: it had no impact on the final model fit and the logic around
+thinning the `fitlog` was needlessly complicated for a trivial performance gain.
 
 By default, the starting values for model fitting are taken from a (non mixed,
 i.e. marginal ) GLM fit. Experience with larger datasets (many thousands of

@@ -247,7 +248,10 @@ function StatsAPI.fit!(
     nAGQ::Integer=1,
     progress::Bool=true,
     thin::Int=typemax(Int),
+    fitlog::Bool=false,
     init_from_lmm=Set(),
+    backend::Symbol=m.optsum.backend,
+    optimizer::Symbol=m.optsum.optimizer,
 ) where {T}
     β = copy(m.β)
     θ = copy(m.θ)
@@ -277,60 +281,33 @@ function StatsAPI.fit!(
         optsum.initial = vcat(β, lm.optsum.final)
         optsum.final = copy(optsum.initial)
     end
-    setpar! = fast ? setθ! : setβθ!
-    prog = ProgressUnknown(; desc="Minimizing", showspeed=true)
-    # start from zero for the initial call to obj before optimization
-    iter = 0
-    fitlog = optsum.fitlog
-    function obj(x, g)
-        isempty(g) || throw(ArgumentError("g should be empty for this objective"))
-        val = try
-            deviance(pirls!(setpar!(m, x), fast, verbose), nAGQ)
-        catch ex
-            # this allows us to recover from models where e.g. the link isn't
-            # as constraining as it should be
-            ex isa Union{PosDefException,DomainError} || rethrow()
-            iter == 1 && rethrow()
-            m.optsum.finitial
-        end
-        iszero(rem(iter, thin)) && push!(fitlog, (copy(x), val))
-        verbose && println(round(val; digits=5), " ", x)
-        progress && ProgressMeter.next!(prog; showvalues=[(:objective, val)])
-        iter += 1
-        return val
-    end
-    opt = Opt(optsum)
-    NLopt.min_objective!(opt, obj)
-    optsum.finitial = obj(optsum.initial, T[])
-    empty!(fitlog)
-    push!(fitlog, (copy(optsum.initial), optsum.finitial))
-    fmin, xmin, ret = NLopt.optimize(opt, copyto!(optsum.final, optsum.initial))
-    ProgressMeter.finish!(prog)
+
+    optsum.backend = backend
+    optsum.optimizer = optimizer
+
+    xmin, fmin = optimize!(m; progress, fitlog, fast, verbose, nAGQ)
+
     ## check if very small parameter values bounded below by zero can be set to zero
     xmin_ = copy(xmin)
     for i in eachindex(xmin_)
         if iszero(optsum.lowerbd[i]) && zero(T) < xmin_[i] < optsum.xtol_zero_abs
             xmin_[i] = zero(T)
         end
     end
-    loglength = length(fitlog)
     if xmin ≠ xmin_
-        if (zeroobj = obj(xmin_, T[])) ≤ (fmin + optsum.ftol_zero_abs)
+        if (zeroobj = objective!(m, xmin_; nAGQ, fast, verbose)) ≤
+           (fmin + optsum.ftol_zero_abs)
             fmin = zeroobj
             copyto!(xmin, xmin_)
-        elseif length(fitlog) > loglength
-            # remove unused extra log entry
-            pop!(fitlog)
+            fitlog && push!(optsum.fitlog, (copy(xmin), fmin))
         end
     end
+
     ## ensure that the parameter values saved in m are xmin
-    pirls!(setpar!(m, xmin), fast, verbose)
-    optsum.nAGQ = nAGQ
-    optsum.feval = opt.numevals
+    objective!(m, xmin; fast, verbose, nAGQ)
     optsum.final = xmin
     optsum.fmin = fmin
-    optsum.returnvalue = ret
-    _check_nlopt_return(ret)
+    optsum.nAGQ = nAGQ
     return m
 end
 
@@ -540,6 +517,30 @@ function StatsAPI.loglikelihood(m::GeneralizedLinearMixedModel{T}) where {T}
     return accum - (mapreduce(u -> sum(abs2, u), +, m.u) + logdet(m)) / 2
 end
 
+# Base.Fix1 doesn't forward kwargs
+function objective!(m::GeneralizedLinearMixedModel; fast=false, kwargs...)
+    return x -> _objective!(m, x, Val(fast); kwargs...)
+end
+
+function objective!(m::GeneralizedLinearMixedModel{T}, x; fast=false, kwargs...) where {T}
+    return _objective!(m, x, Val(fast); kwargs...)
+end
+
+# normally, it doesn't make sense to move a simple branch to dispatch
+# HOWEVER, this winds up getting called in optimization a lot and
+# moving this to a realization here allows us to avoid dynamic dispatch on setθ! / setβθ!
+function _objective!(
+    m::GeneralizedLinearMixedModel{T}, x, ::Val{true}; nAGQ=1, verbose=false
+) where {T}
+    return deviance(pirls!(setθ!(m, x), true, verbose), nAGQ)
+end
+
+function _objective!(
+    m::GeneralizedLinearMixedModel{T}, x, ::Val{false}; nAGQ=1, verbose=false
+) where {T}
+    return deviance(pirls!(setβθ!(m, x), false, verbose), nAGQ)
+end
+
 function Base.propertynames(m::GeneralizedLinearMixedModel, private::Bool=false)
     return (
         :A,