Skip to content

Commit 2e3051a

Browse files
dmbates authored and palday committed
correct coeftable for GLMMs, add tests (#308)
* correct coeftable for GLMMs, add tests * Move methods for abstract MixedModel struct to separate file. (cherry picked from commit 5fb6f65)
1 parent 7dbd4c2 commit 2e3051a

File tree

5 files changed

+133
-56
lines changed

5 files changed

+133
-56
lines changed

src/MixedModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ include("linearmixedmodel.jl")
117117
include("gausshermite.jl")
118118
include("generalizedlinearmixedmodel.jl")
119119
include("mixed.jl")
120+
include("mixedmodel.jl")
120121
include("linalg/statschol.jl")
121122
include("linalg/cholUnblocked.jl")
122123
include("linalg/rankUpdate.jl")

src/generalizedlinearmixedmodel.jl

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,20 @@ struct GeneralizedLinearMixedModel{T <: AbstractFloat} <: MixedModel{T}
5353
mult::Vector{T}
5454
end
5555

56+
function StatsBase.coeftable(m::GeneralizedLinearMixedModel)
57+
co = fixef(m)
58+
se = stderror(m)
59+
z = co ./ se
60+
pvalue = ccdf.(Chisq(1), abs2.(z))
61+
CoefTable(
62+
hcat(co, se, z, pvalue),
63+
["Estimate", "Std.Error", "z value", "P(>|z|)"],
64+
coefnames(m),
65+
4, # pvalcol
66+
3, # teststatcol
67+
)
68+
end
69+
5670
"""
5771
deviance(m::GeneralizedLinearMixedModel{T}, nAGQ=1)::T where {T}
5872
@@ -68,7 +82,7 @@ function StatsBase.deviance(m::GeneralizedLinearMixedModel{T}, nAGQ=1) where {T}
6882
u = vec(first(m.u))
6983
u₀ = vec(first(m.u₀))
7084
copyto!(u₀, u)
71-
ra = RaggedArray(m.resp.devresid, first(m.LMM.reterms).refs)
85+
ra = RaggedArray(m.resp.devresid, first(m.LMM.allterms).refs)
7286
devc0 = sum!(map!(abs2, m.devc0, u), ra) # the deviance components at z = 0
7387
sd = map!(inv, m.sd, m.LMM.L[Block(1,1)].diag)
7488
mult = fill!(m.mult, 0)
@@ -105,8 +119,17 @@ function deviance!(m::GeneralizedLinearMixedModel, nAGQ=1)
105119
deviance(m, nAGQ)
106120
end
107121

108-
GLM.dispersion(m::GeneralizedLinearMixedModel, sqr::Bool=false) =
109-
dispersion(m.resp, dof_residual(m), sqr)
122+
function GLM.dispersion(m::GeneralizedLinearMixedModel{T}, sqr::Bool = false) where {T}
123+
# adapted from GLM.dispersion(::AbstractGLM, ::Bool)
124+
# TODO: PR for a GLM.dispersion(resp::GLM.GlmResp, dof_residual::Int, sqr::Bool)
125+
r = m.resp
126+
if dispersion_parameter(r.d)
127+
s = sum(wt * abs2(re) for (wt, re) in zip(r.wrkwt, r.wrkresid)) / dof_residual(m)
128+
sqr ? s : sqrt(s)
129+
else
130+
one(T)
131+
end
132+
end
110133

111134
GLM.dispersion_parameter(m::GeneralizedLinearMixedModel) = dispersion_parameter(m.resp.d)
112135

@@ -295,7 +318,11 @@ function Base.getproperty(m::GeneralizedLinearMixedModel, s::Symbol)
295318
m.β
296319
elseif s ∈ (:σ, :sigma)
297320
sdest(m)
298-
elseif s ∈ (:A, :L, :λ, :lowerbd, :optsum, :X, :reterms, :feterms, :formula, :σs, :σρs)
321+
elseif s == :σs
322+
σs(m)
323+
elseif s == :σρs
324+
σρs(m)
325+
elseif s ∈ (:A, :L, :λ, :lowerbd, :corr, :PCA, :rePCA, :optsum, :X, :reterms, :feterms, :formula)
299326
getproperty(m.LMM, s)
300327
elseif s == :y
301328
m.resp.y
@@ -305,18 +332,17 @@ function Base.getproperty(m::GeneralizedLinearMixedModel, s::Symbol)
305332
end
306333

307334
function StatsBase.loglikelihood(m::GeneralizedLinearMixedModel{T}) where {T}
308-
accum = zero(T)
335+
r = m.resp
309336
D = Distribution(m.resp)
310-
if D <: Binomial
311-
for (μ, y, n) in zip(m.resp.mu, m.resp.y, m.wt)
312-
accum += logpdf(D(round(Int, n), μ), round(Int, y * n))
337+
accum = (
338+
if D <: Binomial
339+
sum(logpdf(D(round(Int, n), μ), round(Int, y * n))
340+
for (μ, y, n) in zip(r.mu, r.y, m.wt))
341+
else
342+
sum(logpdf(D(μ), y) for (μ, y) in zip(r.mu, r.y))
313343
end
314-
else
315-
for (μ, y) in zip(m.resp.mu, m.resp.y)
316-
accum += logpdf(D(μ), y)
317-
end
318-
end
319-
accum - (mapreduce(u -> sum(abs2, u), + , m.u) + logdet(m)) / 2
344+
)
345+
accum - (sum(sum(abs2, u) for u in m.u) + logdet(m)) / 2
320346
end
321347

322348
StatsBase.nobs(m::GeneralizedLinearMixedModel) = length(m.η)
@@ -467,13 +493,14 @@ varest(m::GeneralizedLinearMixedModel{T}) where {T} = one(T)
467493
for f in (
468494
:describeblocks,
469495
:feL,
496+
:fetrm,
470497
:(LinearAlgebra.logdet),
471498
:lowerbd,
499+
:PCA,
500+
:rePCA,
501+
:(StatsBase.coefnames),
472502
:(StatsModels.modelmatrix),
473-
:(StatsBase.vcov),
474-
:σs,
475-
:σρs,
476-
)
503+
)
477504
@eval begin
478505
$f(m::GeneralizedLinearMixedModel) = $f(m.LMM)
479506
end

src/linearmixedmodel.jl

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,6 @@ function StatsBase.coeftable(m::MixedModel)
126126
first(m.feterms).cnames, 4)
127127
end
128128

129-
"""
130-
cond(m::MixedModel)
131-
132-
Return a vector of condition numbers of the λ matrices for the random-effects terms
133-
"""
134-
LinearAlgebra.cond(m::MixedModel) = cond.(m.λ)
135-
136129
"""
137130
condVar(m::LinearMixedModel)
138131
@@ -231,12 +224,29 @@ function StatsBase.dof_residual(m::LinearMixedModel)::Int
231224
end
232225

233226
"""
234-
feL(m::MixedModel)
227+
feind(m::LinearMixedModel)
228+
229+
An internal utility to return the index in `m.allterms` of the fixed-effects term.
230+
"""
231+
feind(m::LinearMixedModel) = findfirst(Base.Fix2(isa, FeMat), m.allterms)
232+
233+
"""
234+
feL(m::LinearMixedModel)
235235
236236
Return the lower Cholesky factor for the fixed-effects parameters, as a `LowerTriangular`
237237
`p × p` matrix.
238238
"""
239-
feL(m::LinearMixedModel) = LowerTriangular(m.L.blocks[end - 1, end - 1])
239+
function feL(m::LinearMixedModel)
240+
k = feind(m)
241+
LowerTriangular(m.L.blocks[k, k])
242+
end
243+
244+
"""
245+
fetrm(m::LinearMixedModel)
246+
247+
Return the fixed-effects term from `m.allterms`
248+
"""
249+
fetrm(m::LinearMixedModel) = m.allterms[feind(m)]
240250

241251
"""
242252
fit!(m::LinearMixedModel[; verbose::Bool=false, REML::Bool=false])
@@ -636,16 +646,13 @@ function Base.show(io::IO, m::LinearMixedModel)
636646
show(io,coeftable(m))
637647
end
638648

639-
function σs(m::LinearMixedModel)
640-
σ = sdest(m)
641-
NamedTuple{fnames(m)}(((σs(t, σ) for t in m.reterms)...,))
642-
end
643-
644-
function σρs(m::LinearMixedModel)
645-
σ = sdest(m)
646-
NamedTuple{fnames(m)}(((σρs(t, σ) for t in m.reterms)...,))
647-
end
649+
"""
650+
size(m::LinearMixedModel)
648651
652+
Returns the size of a mixed model as a tuple of length four:
653+
the number of observations, the number of (non-singular) fixed-effects parameters,
654+
the number of conditional modes (random effects), the number of grouping variables
655+
"""
649656
function Base.size(m::LinearMixedModel)
650657
n, p = size(first(m.feterms))
651658
n, p, sum(size.(m.reterms, 2)), length(m.reterms)
@@ -660,10 +667,26 @@ This value is the contents of the `1 × 1` bottom right block of `m.L`
660667
"""
661668
sqrtpwrss(m::LinearMixedModel) = first(m.L.blocks[end, end])
662669

670+
671+
"""
672+
ssqdenom(m::LinearMixedModel)
673+
Return the denominator for penalized sums-of-squares.
674+
For MLE, this value is the number of observations. For REML, this
675+
value is the number of observations minus the rank of the fixed-effects matrix.
676+
The difference is analogous to the use of n or n-1 in the denominator when
677+
calculating the variance.
678+
"""
679+
function ssqdenom(m::LinearMixedModel)::Int
680+
(n, p, q, k) = size(m)
681+
m.optsum.REML ? n - p : n
682+
end
683+
663684
"""
664685
std(m::MixedModel)
665686
666687
Return the estimated standard deviations of the random effects as a `Vector{Vector{T}}`.
688+
689+
FIXME: This uses an old convention of isfinite(sdest(m)). Probably drop in favor of m.σs
667690
"""
668691
function Statistics.std(m::LinearMixedModel)
669692
rl = rowlengths.(m.reterms)
@@ -739,25 +762,7 @@ end
739762
740763
Returns the estimate of σ², the variance of the conditional distribution of Y given B.
741764
"""
742-
varest(m::LinearMixedModel) = pwrss(m) / dof_residual(m)
743-
744-
function StatsBase.vcov(m::LinearMixedModel{T}) where {T}
745-
Xtrm = first(m.feterms)
746-
iperm = invperm(Xtrm.piv)
747-
p = length(iperm)
748-
r = Xtrm.rank
749-
Linv = inv(feL(m))
750-
permvcov = varest(m) * (Linv'Linv)
751-
if p == Xtrm.rank
752-
permvcov[iperm, iperm]
753-
else
754-
covmat = fill(zero(T)/zero(T), (p, p))
755-
for j in 1:r, i in 1:r
756-
covmat[i,j] = permvcov[i, j]
757-
end
758-
covmat[iperm, iperm]
759-
end
760-
end
765+
varest(m::LinearMixedModel) = pwrss(m) / ssqdenom(m)
761766

762767
"""
763768
zerocorr!(m::LinearMixedModel[, trmnms::Vector{Symbol}])

src/mixedmodel.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
"""
3+
cond(m::MixedModel)
4+
5+
Return a vector of condition numbers of the λ matrices for the random-effects terms
6+
"""
7+
LinearAlgebra.cond(m::MixedModel) = cond.(m.λ)
8+
9+
function σs(m::MixedModel)
10+
σ = dispersion(m)
11+
NamedTuple{fnames(m)}(((σs(t, σ) for t in m.reterms)...,))
12+
end
13+
14+
function σρs(m::MixedModel)
15+
σ = dispersion(m)
16+
NamedTuple{fnames(m)}(((σρs(t, σ) for t in m.reterms)...,))
17+
end
18+
19+
"""
20+
vcov(m::MixedModel; corr=false)
21+
22+
Returns the variance-covariance matrix of the fixed effects.
23+
If `corr=true`, then correlation of fixed effects is returned instead.
24+
"""
25+
function StatsBase.vcov(m::MixedModel; corr=false)
26+
Xtrm = fetrm(m)
27+
iperm = invperm(Xtrm.piv)
28+
p = length(iperm)
29+
r = Xtrm.rank
30+
Linv = inv(feL(m))
31+
T = eltype(Linv)
32+
permvcov = dispersion(m, true) * (Linv'Linv)
33+
if p == Xtrm.rank
34+
vv = permvcov[iperm, iperm]
35+
else
36+
covmat = fill(zero(T) / zero(T), (p, p))
37+
for j = 1:r, i = 1:r
38+
covmat[i, j] = permvcov[i, j]
39+
end
40+
vv = covmat[iperm, iperm]
41+
end
42+
43+
corr ? StatsBase.cov2cor!(vv, stderror(m)) : vv
44+
end

test/pirls.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ end
5151
gm2 = fit(MixedModel, @formula(prop ~ 1 + p + (1|h)), cbpp, Binomial(), wts = cbpp[!,:s])
5252
@test isapprox(deviance(gm2,true), 100.09585619324639, atol=0.0001)
5353
@test isapprox(sum(abs2, gm2.u[1]), 9.723175126731014, atol=0.0001)
54-
@test isapprox(logdet(gm2), 16.90099, atol=0.0001)
54+
@test isapprox(logdet(gm2), 16.90099, atol=0.001)
5555
@test isapprox(sum(gm2.resp.devresid), 73.47179193718736, atol=0.001)
5656
@test isapprox(loglikelihood(gm2), -92.02628186555876, atol=0.001)
5757
@test isnan(sdest(gm2))

0 commit comments

Comments
 (0)