Skip to content

Commit 553e39f

Browse files
dmbatespalday
andauthored
Tighten code, increase test coverage (#301)
* Tighten code, increase test coverage * relax comparison tolerance * Just-in-time model fits in testing * Revise bounds checking * Revise issingular, tests and fit! logic. * Clean up dependencies * Add tests of show with r.e. correlations * Correct the name of an error condition Co-authored-by: Phillip Alday <[email protected]>
1 parent dec030a commit 553e39f

File tree

11 files changed

+166
-142
lines changed

11 files changed

+166
-142
lines changed

src/MixedModels.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ export @formula,
5656
RandomEffectsTerm,
5757
ReMat,
5858
SqrtLink,
59-
TestData,
6059
UniformBlockDiagonal,
6160
VarCorr,
6261

@@ -84,6 +83,7 @@ export @formula,
8483
GHnorm,
8584
issingular,
8685
leverage,
86+
logdet,
8787
loglikelihood,
8888
lowerbd,
8989
nobs,
@@ -93,6 +93,7 @@ export @formula,
9393
predict,
9494
pwrss,
9595
ranef,
96+
rank,
9697
refit!,
9798
replicate,
9899
residuals,

src/arraytypes.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ function UniformBlockDiagonal(dat::Array{T,3}) where {T}
1515
)
1616
end
1717

18+
function Base.axes(A::UniformBlockDiagonal)
19+
m, n, l = size(A.data)
20+
(Base.OneTo(m * l), Base.OneTo(n * l))
21+
end
22+
1823
function Base.copyto!(dest::UniformBlockDiagonal{T}, src::UniformBlockDiagonal{T}) where {T}
1924
sdat = src.data
2025
ddat = dest.data
@@ -42,10 +47,9 @@ function Base.copyto!(dest::Matrix{T}, src::UniformBlockDiagonal{T}) where {T}
4247
end
4348

4449
function Base.getindex(A::UniformBlockDiagonal{T}, i::Int, j::Int) where {T}
50+
@boundscheck checkbounds(A, i, j)
4551
Ad = A.data
4652
m, n, l = size(Ad)
47-
(0 < i l * m && 0 < j l * n) ||
48-
throw(IndexError("attempt to access $(l*m) × $(l*n) array at index [$i, $j]"))
4953
iblk, ioffset = divrem(i - 1, m)
5054
jblk, joffset = divrem(j - 1, n)
5155
iblk == jblk ? Ad[ioffset+1, joffset+1, iblk+1] : zero(T)

src/blockdescription.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ shorttype(::Diagonal,::Diagonal) = "Diagonal"
4343
shorttype(::Diagonal,::Matrix) = "Diag/Dense"
4444
shorttype(::Matrix,::Matrix) = "Dense"
4545
shorttype(::SparseMatrixCSC,::SparseMatrixCSC) = "Sparse"
46+
shorttype(::SparseMatrixCSC,::Matrix) = "Sparse/Dense"
47+
4648

4749
function Base.show(io::IO, ::MIME"text/plain", b::BlockDescription)
4850
rowwidth = max(maximum(ndigits, b.blkrows) + 1, 5)

src/bootstrap.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,6 @@ end
180180

181181
issingular(bsamp::MixedModelBootstrap) = map-> any.≈ bsamp.lowerbd), bsamp.θ)
182182

183-
ppoints(n::Integer) = inv(2n):inv(n):1
184-
185183
function Base.propertynames(bsamp::MixedModelBootstrap)
186184
[:allpars, :objective, , , , :σs, , :inds, :lowerbd, :bstr, :fcnames]
187185
end

src/likelihoodratiotest.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,6 @@ function _likelihoodratiotest(m::Vararg{T}) where T <: MixedModel
9191
)
9292
end
9393

94-
function _array_union_nothing(arr::Array{T}) where T
95-
Array{Union{T,Nothing}}(arr)
96-
end
97-
98-
function _prepend_0(arr::Array{T}) where T
99-
pushfirst!(copy(arr), -zero(T))
100-
end
101-
10294
function Base.show(io::IO, lrt::LikelihoodRatioTest; digits=2)
10395
println(io, "Model Formulae")
10496

src/linalg.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,6 @@ function LinearAlgebra.mul!(
2525
C
2626
end
2727

28-
LinearAlgebra.mul!(
29-
C::Matrix{T},
30-
A::BlockedSparse{T},
31-
adjB::Adjoint{T,<:Matrix{T}},
32-
α::Number,
33-
β::Number,
34-
) where {T} = mul!(C, A.cscmat, adjB, α, β)
35-
36-
LinearAlgebra.mul!(
37-
C::BlockedSparse{T},
38-
A::BlockedSparse{T},
39-
adjB::Adjoint{T,<:BlockedSparse{T}},
40-
α::Number,
41-
β::Number,
42-
) where {T} = mul!(C.cscmat, A.cscmat, adjB.parent.cscmat', α, β)
43-
4428
LinearAlgebra.mul!(
4529
C::StridedVecOrMat{T},
4630
A::StridedVecOrMat{T},

src/linalg/rankUpdate.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,10 @@ function rankUpdate!(
127127
C
128128
end
129129

130+
#=
130131
rankUpdate!(C::HermOrSym{T,Matrix{T}}, A::BlockedSparse{T}, α = true) where {T} =
131132
rankUpdate!(C, A.cscmat, α)
133+
=#
132134

133135
function rankUpdate!(
134136
C::Diagonal{T,S},

src/linearmixedmodel.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -303,17 +303,15 @@ objective and the parameters are printed on stdout at each function evaluation.
303303
function fit!(m::LinearMixedModel{T}; verbose::Bool = false, REML::Bool = false) where {T}
304304
optsum = m.optsum
305305
opt = Opt(optsum)
306-
feval = 0
307306
optsum.REML = REML
308307
function obj(x, g)
309-
isempty(g) || error("gradient not defined")
310-
feval += 1
308+
isempty(g) || throw(ArgumentError("g should be empty for this objective"))
311309
val = objective(updateL!(setθ!(m, x)))
312-
feval == 1 && (optsum.finitial = val)
313-
verbose && println("f_", feval, ": ", round(val, digits = 5), " ", x)
310+
verbose && println(round(val, digits = 5), " ", x)
314311
val
315312
end
316313
NLopt.min_objective!(opt, obj)
314+
optsum.finitial = obj(optsum.initial, T[])
317315
fmin, xmin, ret = NLopt.optimize!(opt, copyto!(optsum.final, optsum.initial))
318316
## check if small non-negative parameter values can be set to zero
319317
xmin_ = copy(xmin)
@@ -332,12 +330,12 @@ function fit!(m::LinearMixedModel{T}; verbose::Bool = false, REML::Bool = false)
332330
## ensure that the parameter values saved in m are xmin
333331
updateL!(setθ!(m, xmin))
334332

335-
optsum.feval = feval
333+
optsum.feval = opt.numevals
336334
optsum.final = xmin
337335
optsum.fmin = fmin
338336
optsum.returnvalue = ret
339337
ret == :ROUNDOFF_LIMITED && @warn("NLopt was roundoff limited")
340-
if ret [:FAILURE, :INVALID_ARGS, :OUT_OF_MEMORY, :FORCED_STOP, :MAXFEVAL_REACHED]
338+
if ret [:FAILURE, :INVALID_ARGS, :OUT_OF_MEMORY, :FORCED_STOP, :MAXEVAL_REACHED]
341339
@warn("NLopt optimization failure: $ret")
342340
end
343341
m
@@ -474,8 +472,10 @@ end
474472
issingular(m::LinearMixedModel, θ=m.θ)
475473
476474
Test whether the model `m` is singular if the parameter vector is `θ`.
475+
476+
Equality comparisons are used b/c small non-negative θ values are replaced by 0 in `fit!`.
477477
"""
478-
issingular(m::LinearMixedModel, θ=m.θ) = any(isapprox.(lowerbd(m), θ))
478+
issingular(m::LinearMixedModel, θ=m.θ) = any(lowerbd(m) .== θ)
479479

480480
function StatsBase.leverage(m::LinearMixedModel{T}) where {T}
481481
# This can be done more efficiently but reusing existing tools is easier.
@@ -506,11 +506,11 @@ end
506506
lowerbd(m::LinearMixedModel) = m.optsum.lowerbd
507507

508508
function StatsBase.modelmatrix(m::LinearMixedModel)
509-
fetrm = first(m.feterms)
510-
if fetrm.rank == size(fetrm, 2)
511-
fetrm.x
509+
fe = fetrm(m)
510+
if fe.rank == size(fe, 2)
511+
fe.x
512512
else
513-
fetrm.x[:, invperm(fetrm.piv)]
513+
fe.x[:, invperm(fe.piv)]
514514
end
515515
end
516516

test/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
33
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
44
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
55
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6-
MixedModels = "ff71e718-51f3-5ec2-a782-8ffcbfa3c316"
76
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
87
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
98
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

test/pirls.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,31 @@
1-
using DataFrames
2-
using LinearAlgebra
31
using MixedModels
42
using Test
53

4+
using MixedModels: dataset
5+
6+
const fms = Dict(
7+
:cbpp => [@formula((incid/hsz) ~ 1 + period + (1|herd))],
8+
:contra => [@formula(use ~ 1+age+abs2(age)+urban+livch+(1|urbdist))],
9+
:grouseticks => [@formula(ticks ~ 1+year+ch+ (1|index) + (1|brood) + (1|location))],
10+
:verbagg => [@formula(r2 ~ 1+anger+gender+btype+situ+(1|subj)+(1|item))],
11+
)
12+
613
@testset "contra" begin
7-
contra = MixedModels.dataset(:contra)
8-
contraform = @formula(use ~ 1+age+abs2(age)+urban+livch+(1|urbdist))
9-
gm0 = fit(MixedModel, contraform, contra, Bernoulli(), fast=true);
14+
contra = dataset(:contra)
15+
gm0 = fit(MixedModel, only(fms[:contra]), contra, Bernoulli(), fast=true);
1016
@test gm0.lowerbd == zeros(1)
1117
@test isapprox(gm0.θ, [0.5720734451352923], atol=0.001)
12-
@test isapprox(deviance(gm0,true), 2361.657188518064, atol=0.001)
13-
gm1 = fit(MixedModel, contraform, contra, Bernoulli());
18+
@test isapprox(deviance(gm0), 2361.657188518064, atol=0.001)
19+
gm1 = fit(MixedModel, only(fms[:contra]), contra, Bernoulli());
1420
@test isapprox(gm1.θ, [0.573054], atol=0.005)
1521
@test lowerbd(gm1) == vcat(fill(-Inf, 7), 0.)
16-
@test isapprox(deviance(gm1,true), 2361.54575, rtol=0.00001)
22+
@test isapprox(deviance(gm1), 2361.54575, rtol=0.00001)
1723
@test isapprox(loglikelihood(gm1), -1180.77288, rtol=0.00001)
1824
@test dof(gm0) == length(gm0.β) + length(gm0.θ)
1925
@test nobs(gm0) == 1934
2026
fit!(gm0, fast=true, nAGQ=7)
2127
@test isapprox(deviance(gm0), 2360.9838, atol=0.001)
22-
gm1 = fit(MixedModel, contraform, contra, Bernoulli(), nAGQ=7)
28+
gm1 = fit(MixedModel, only(fms[:contra]), contra, Bernoulli(), nAGQ=7)
2329
@test isapprox(deviance(gm1), 2360.8760, atol=0.001)
2430
@test gm1.β == gm1.beta
2531
@test gm1.θ == gm1.theta
@@ -31,7 +37,6 @@ using Test
3137
@test length(MixedModels.rePCA(gm0)) == 1
3238
@test length(gm0.rePCA) == 1
3339
end
34-
# gm0.βθ = vcat(gm0.β, gm0.theta)
3540
# the next three values are not well defined in the optimization
3641
#@test isapprox(logdet(gm1), 75.7217, atol=0.1)
3742
#@test isapprox(sum(abs2, gm1.u[1]), 48.4747, atol=0.1)
@@ -41,8 +46,8 @@ using Test
4146
end
4247

4348
@testset "cbpp" begin
44-
cbpp = MixedModels.dataset(:cbpp)
45-
gm2 = fit(MixedModel, @formula((incid/hsz) ~ 1 + period + (1|herd)), cbpp, Binomial(), wts=float(cbpp.hsz))
49+
cbpp = dataset(:cbpp)
50+
gm2 = fit(MixedModel, only(fms[:cbpp]), cbpp, Binomial(), wts=float(cbpp.hsz))
4651
@test deviance(gm2,true) 100.09585619892968 atol=0.0001
4752
@test sum(abs2, gm2.u[1]) 9.723054788538546 atol=0.0001
4853
@test logdet(gm2) 16.90105378801136 atol=0.0001
@@ -53,22 +58,20 @@ end
5358
end
5459

5560
@testset "verbagg" begin
56-
gm3 = fit(MixedModel, @formula(r2 ~ 1+anger+gender+btype+situ+(1|subj)+(1|item)),
57-
MixedModels.dataset(:verbagg), Bernoulli())
61+
gm3 = fit(MixedModel, only(fms[:verbagg]), dataset(:verbagg), Bernoulli())
5862
@test deviance(gm3) 8151.40 rtol=1e-5
5963
@test lowerbd(gm3) == vcat(fill(-Inf, 6), zeros(2))
6064
@test fitted(gm3) == predict(gm3)
6165
# these two values are not well defined at the optimum
62-
@test isapprox(sum(x -> sum(abs2, x), gm3.u), 273.29266717430795, rtol=1e-3)
63-
@test sum(gm3.resp.devresid) 7156.547357801238 rtol=1e-4
66+
@test isapprox(sum(x -> sum(abs2, x), gm3.u), 273.29646346940785, rtol=1e-3)
67+
@test sum(gm3.resp.devresid) 7156.550941446312 rtol=1e-4
6468
end
6569

6670
@testset "grouseticks" begin
6771
center(v::AbstractVector) = v .- (sum(v) / length(v))
68-
grouseticks = MixedModels.dataset(:grouseticks)
72+
grouseticks = dataset(:grouseticks)
6973
grouseticks.ch = center(grouseticks.height)
70-
gm4 = fit(MixedModel, @formula(ticks ~ 1+year+ch+ (1|index) + (1|brood) + (1|location)),
71-
grouseticks, Poisson(), fast=true) # fails in pirls! with fast=false
74+
gm4 = fit(MixedModel, only(fms[:grouseticks]), grouseticks, Poisson(), fast=true) # fails in pirls! with fast=false
7275
@test isapprox(deviance(gm4), 851.4046, atol=0.001)
7376
# these two values are not well defined at the optimum
7477
#@test isapprox(sum(x -> sum(abs2, x), gm4.u), 196.8695297987013, atol=0.1)

0 commit comments

Comments
 (0)