Skip to content

Commit f03c65b

Browse files
authored
restore and save optsum for GLMM (#791)
* start work on restoreoptsum for glmm * unfit! now wipes more state * restore and save optsum for GLMM * NEWS, version bump * Blue Style * undo change to LMM? * news tweak * He's always watching your style Co-authored-by: Alex Arslan <[email protected]> * in the dark Co-authored-by: Alex Arslan <[email protected]> * tweaks Co-authored-by: Alex Arslan <[email protected]> * test update * :Facepalm:
1 parent c0b9e2d commit f03c65b

File tree

7 files changed

+117
-26
lines changed

7 files changed

+117
-26
lines changed

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
MixedModels v4.27.0 Release Notes
2+
==============================
3+
- `saveoptsum` and `restoreoptsum!` now support `GeneralizedLinearMixedModel`s [#791]
4+
- `unfit!(::GeneralizedLinearMixedModel)` (called internally by `refit!`) now does a better job of fully resetting the model state [#791]
5+
16
MixedModels v4.26.1 Release Notes
27
==============================
38
- lower and upper edges of profile confidence intervals for REML-fitted models are no longer flipped [#785]
@@ -569,3 +574,4 @@ Package dependencies
569574
[#778]: https://github.com/JuliaStats/MixedModels.jl/issues/778
570575
[#783]: https://github.com/JuliaStats/MixedModels.jl/issues/783
571576
[#785]: https://github.com/JuliaStats/MixedModels.jl/issues/785
577+
[#791]: https://github.com/JuliaStats/MixedModels.jl/issues/791

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MixedModels"
22
uuid = "ff71e718-51f3-5ec2-a782-8ffcbfa3c316"
33
author = ["Phillip Alday <[email protected]>", "Douglas Bates <[email protected]>", "Jose Bayoan Santiago Calderon <[email protected]>"]
4-
version = "4.26.1"
4+
version = "4.27.0"
55

66
[deps]
77
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"

src/generalizedlinearmixedmodel.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,19 +767,21 @@ function stderror!(v::AbstractVector{T}, m::GeneralizedLinearMixedModel{T}) wher
767767
end
768768

769769
function unfit!(model::GeneralizedLinearMixedModel{T}) where {T}
770-
deviance!(model, 1)
771770
reevaluateAend!(model.LMM)
772771

773772
reterms = model.LMM.reterms
774773
optsum = model.LMM.optsum
775774
# we need to reset optsum so that it
776775
# plays nice with the modifications fit!() does
777776
optsum.lowerbd = mapfoldl(lowerbd, vcat, reterms)
778-
optsum.initial = mapfoldl(getθ, vcat, reterms)
777+
# for variances (bounded at zero), we have ones, while
778+
# for everything else (bounded at -Inf), we have zeros
779+
optsum.initial = map(T iszero, optsum.lowerbd)
779780
optsum.final = copy(optsum.initial)
780781
optsum.xtol_abs = fill!(copy(optsum.initial), 1.0e-10)
781782
optsum.initial_step = T[]
782783
optsum.feval = -1
784+
deviance!(model, 1)
783785

784786
return model
785787
end

src/optsummary.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,9 @@ function _check_nlopt_return(ret, failure_modes=_NLOPT_FAILURE_MODES)
162162
@warn("NLopt optimization failure: $ret")
163163
end
164164
end
165+
166+
function Base.:(==)(o1::OptSummary{T}, o2::OptSummary{T}) where {T}
167+
return all(fieldnames(OptSummary)) do fn
168+
return getfield(o1, fn) == getfield(o2, fn)
169+
end
170+
end

src/serialization.jl

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,78 @@
11
"""
2-
restoreoptsum!(m::LinearMixedModel, io::IO; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
3-
restoreoptsum!(m::LinearMixedModel, filename; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
2+
restoreoptsum!(m::MixedModel, io::IO; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
3+
restoreoptsum!(m::MixedModel, filename; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
44
55
Read, check, and restore the `optsum` field from a JSON stream or filename.
66
"""
7+
function restoreoptsum!(m::MixedModel, filename; kwargs...)
8+
return open(filename, "r") do io
9+
return restoreoptsum!(m, io; kwargs...)
10+
end
11+
end
12+
713
function restoreoptsum!(
814
m::LinearMixedModel{T}, io::IO; atol::Real=zero(T),
915
rtol::Real=atol > 0 ? zero(T) : eps(T),
16+
) where {T}
17+
dict = JSON3.read(io)
18+
ops = restoreoptsum!(m.optsum, dict)
19+
for (par, obj_at_par) in (:initial => :finitial, :final => :fmin)
20+
if !isapprox(
21+
objective(updateL!(setθ!(m, getfield(ops, par)))), getfield(ops, obj_at_par);
22+
rtol, atol,
23+
)
24+
throw(
25+
ArgumentError(
26+
"model at $par does not match stored $obj_at_par within atol=$atol, rtol=$rtol"
27+
),
28+
)
29+
end
30+
end
31+
return m
32+
end
33+
34+
function restoreoptsum!(
35+
m::GeneralizedLinearMixedModel{T}, io::IO; atol::Real=zero(T),
36+
rtol::Real=atol > 0 ? zero(T) : eps(T),
1037
) where {T}
1138
dict = JSON3.read(io)
1239
ops = m.optsum
40+
41+
# need to accommodate fast and slow fits
42+
resize!(ops.initial, length(dict.initial))
43+
resize!(ops.final, length(dict.final))
44+
45+
theta_beta_len = length(m.θ) + length(m.β)
46+
if length(dict.initial) == theta_beta_len # fast=false
47+
if length(ops.lowerbd) == length(m.θ)
48+
prepend!(ops.lowerbd, fill(-Inf, length(m.β)))
49+
end
50+
setpar! = setβθ!
51+
varyβ = false
52+
else # fast=true
53+
setpar! = setθ!
54+
varyβ = true
55+
if length(ops.lowerbd) != length(m.θ)
56+
deleteat!(ops.lowerbd, 1:length(m.β))
57+
end
58+
end
59+
restoreoptsum!(ops, dict)
60+
for (par, obj_at_par) in (:initial => :finitial, :final => :fmin)
61+
if !isapprox(
62+
deviance(pirls!(setpar!(m, getfield(ops, par)), varyβ), dict.nAGQ),
63+
getfield(ops, obj_at_par); rtol, atol,
64+
)
65+
throw(
66+
ArgumentError(
67+
"model at $par does not match stored $obj_at_par within atol=$atol, rtol=$rtol"
68+
),
69+
)
70+
end
71+
end
72+
return m
73+
end
74+
75+
function restoreoptsum!(ops::OptSummary{T}, dict::AbstractDict) where {T}
1376
allowed_missing = (
1477
:lowerbd, # never saved, -Inf not allowed in JSON
1578
:xtol_zero_abs, # added in v4.25.0
@@ -27,7 +90,9 @@ function restoreoptsum!(
2790
if length(setdiff(allowed_missing, keys(dict))) > 1 # 1 because :lowerbd
2891
@warn "optsum was saved with an older version of MixedModels.jl: consider resaving."
2992
end
93+
3094
if any(ops.lowerbd .> dict.initial) || any(ops.lowerbd .> dict.final)
95+
@debug "" ops.lowerbd dict.initial dict.final
3196
throw(ArgumentError("initial or final parameters in io do not satisfy lowerbd"))
3297
end
3398
for fld in (:feval, :finitial, :fmin, :ftol_rel, :ftol_abs, :maxfeval, :nAGQ, :REML)
@@ -37,13 +102,6 @@ function restoreoptsum!(
37102
ops.xtol_rel = copy(dict.xtol_rel)
38103
copyto!(ops.initial, dict.initial)
39104
copyto!(ops.final, dict.final)
40-
for (v, f) in (:initial => :finitial, :final => :fmin)
41-
if !isapprox(
42-
objective(updateL!(setθ!(m, getfield(ops, v)))), getfield(ops, f); rtol, atol
43-
)
44-
throw(ArgumentError("model m at $v does not give stored $f"))
45-
end
46-
end
47105
ops.optimizer = Symbol(dict.optimizer)
48106
ops.returnvalue = Symbol(dict.returnvalue)
49107
# compatibility with fits saved before the introduction of various extensions
@@ -59,30 +117,23 @@ function restoreoptsum!(
59117
else
60118
[(convert(Vector{T}, first(entry)), T(last(entry))) for entry in fitlog]
61119
end
62-
return m
63-
end
64-
65-
function restoreoptsum!(m::LinearMixedModel{T}, filename; kwargs...) where {T}
66-
open(filename, "r") do io
67-
restoreoptsum!(m, io; kwargs...)
68-
end
120+
return ops
69121
end
70122

71123
"""
72-
saveoptsum(io::IO, m::LinearMixedModel)
73-
saveoptsum(filename, m::LinearMixedModel)
124+
saveoptsum(io::IO, m::MixedModel)
125+
saveoptsum(filename, m::MixedModel)
74126
75127
Save `m.optsum` (w/o the `lowerbd` field) in JSON format to an IO stream or a file
76128
77129
The reason for omitting the `lowerbd` field is because it often contains `-Inf`
78130
values that are not allowed in JSON.
79131
"""
80-
saveoptsum(io::IO, m::LinearMixedModel) = JSON3.write(io, m.optsum)
81-
function saveoptsum(filename, m::LinearMixedModel)
132+
saveoptsum(io::IO, m::MixedModel) = JSON3.write(io, m.optsum)
133+
function saveoptsum(filename, m::MixedModel)
82134
open(filename, "w") do io
83135
saveoptsum(io, m)
84136
end
85137
end
86138

87-
# TODO: write methods for GLMM
88139
# TODO, maybe: something nice for the MixedModelBootstrap

test/pirls.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,29 @@ end
239239
@test isapprox(first(gm5.β), -0.13860166843315044, atol=1.e-3)
240240
@test isapprox(last(gm5.β), -0.034414458364713504, atol=1.e-3)
241241
end
242+
243+
@testset "GLMM saveoptsum" begin
244+
cbpp = dataset(:cbpp)
245+
gm_original = GeneralizedLinearMixedModel(first(gfms[:cbpp]), cbpp, Binomial(); wts=cbpp.hsz)
246+
gm_restored = GeneralizedLinearMixedModel(first(gfms[:cbpp]), cbpp, Binomial(); wts=cbpp.hsz)
247+
fit!(gm_original; progress=false, nAGQ=1)
248+
249+
io = IOBuffer()
250+
251+
saveoptsum(seekstart(io), gm_original)
252+
restoreoptsum!(gm_restored, seekstart(io))
253+
@test gm_original.optsum == gm_restored.optsum
254+
@test deviance(gm_original) deviance(gm_restored)
255+
256+
refit!(gm_original; progress=false, nAGQ=3)
257+
saveoptsum(seekstart(io), gm_original)
258+
restoreoptsum!(gm_restored, seekstart(io))
259+
@test gm_original.optsum == gm_restored.optsum
260+
@test deviance(gm_original) deviance(gm_restored)
261+
262+
refit!(gm_original; progress=false, fast=true)
263+
saveoptsum(seekstart(io), gm_original)
264+
restoreoptsum!(gm_restored, seekstart(io))
265+
@test gm_original.optsum == gm_restored.optsum
266+
@test deviance(gm_original) deviance(gm_restored)
267+
end

test/pls.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,8 @@ end
530530
fm_mod = deepcopy(fm)
531531
fm_mod.optsum.fmin += 1
532532
saveoptsum(seekstart(io), fm_mod)
533-
@test_throws(ArgumentError("model m at final does not give stored fmin"),
534-
restoreoptsum!(m, seekstart(io)))
533+
@test_throws(ArgumentError("model at final does not match stored fmin within atol=0.0, rtol=1.0e-8"),
534+
restoreoptsum!(m, seekstart(io); atol=0.0, rtol=1e-8))
535535
restoreoptsum!(m, seekstart(io); atol=1)
536536
@test m.optsum.fmin - fm.optsum.fmin 1
537537

0 commit comments

Comments
 (0)