Skip to content

Commit 2974509

Browse files
authored
FiniteDiff.jl support (#842)
* FiniteDiff.jl support * NEWS link * format * windows. * rename shared example block * using currying method
1 parent 446d036 commit 2974509

File tree

8 files changed

+127
-14
lines changed

8 files changed

+127
-14
lines changed

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
MixedModels v4.38.0 Release Notes
2+
==============================
3+
- Experimental support for evaluating `FiniteDiff.finite_difference_gradient` and `FiniteDiff.finite_difference_hessian of the objective of a fitted `LinearMixedModel`. [#842]
4+
15
MixedModels v4.37.0 Release Notes
26
==============================
37
- Experimental support for evaluating `ForwardDiff.gradient` and `ForwardDiff.hessian` of the objective of a fitted `LinearMixedModel`. [#841]
@@ -652,3 +656,4 @@ Package dependencies
652656
[#829]: https://github.com/JuliaStats/MixedModels.jl/issues/829
653657
[#836]: https://github.com/JuliaStats/MixedModels.jl/issues/836
654658
[#841]: https://github.com/JuliaStats/MixedModels.jl/issues/841
659+
[#842]: https://github.com/JuliaStats/MixedModels.jl/issues/842

Project.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MixedModels"
22
uuid = "ff71e718-51f3-5ec2-a782-8ffcbfa3c316"
33
author = ["Phillip Alday <[email protected]>", "Douglas Bates <[email protected]>"]
4-
version = "4.37.0"
4+
version = "4.38.0"
55

66
[deps]
77
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
@@ -31,12 +31,14 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
3131
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
3232

3333
[weakdeps]
34-
PRIMA = "0a7d04aa-8ac2-47b3-b7a7-9dbd6ad661ed"
34+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
3535
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
36+
PRIMA = "0a7d04aa-8ac2-47b3-b7a7-9dbd6ad661ed"
3637

3738
[extensions]
38-
MixedModelsPRIMAExt = ["PRIMA"]
39+
MixedModelsFiniteDiffExt = ["FiniteDiff"]
3940
MixedModelsForwardDiffExt = ["ForwardDiff"]
41+
MixedModelsPRIMAExt = ["PRIMA"]
4042

4143
[compat]
4244
Aqua = "0.8"
@@ -47,6 +49,7 @@ DataAPI = "1"
4749
DataFrames = "1"
4850
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
4951
ExplicitImports = "1.3"
52+
FiniteDiff = "2.27"
5053
ForwardDiff = "1"
5154
GLM = "1.8.2"
5255
InteractiveUtils = "1"
@@ -80,6 +83,7 @@ julia = "1.10"
8083
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
8184
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
8285
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
86+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
8387
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8488
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
8589
PRIMA = "0a7d04aa-8ac2-47b3-b7a7-9dbd6ad661ed"
@@ -89,4 +93,4 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
8993
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9094

9195
[targets]
92-
test = ["Aqua", "DataFrames", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "PRIMA", "RegressionFormulae", "StableRNGs", "Suppressor", "Test"]
96+
test = ["Aqua", "DataFrames", "ExplicitImports", "FiniteDiff", "ForwardDiff", "InteractiveUtils", "PRIMA", "RegressionFormulae", "StableRNGs", "Suppressor", "Test"]

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
44
DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
55
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
66
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
7+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
78
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
89
FreqTables = "da1fdf0e-e0ff-5433-a45f-9bb5ff651cb1"
910
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"

docs/make.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Documenter
22
using MixedModels
3+
using FiniteDiff
34
using ForwardDiff
45
using StatsAPI
56
using StatsBase
@@ -8,6 +9,8 @@ makedocs(;
89
sitename="MixedModels",
910
format=Documenter.HTML(; size_threshold=500_000, size_threshold_warn=250_000),
1011
doctest=true,
12+
# pagesonly=true,
13+
# warnonly=true,
1114
pages=[
1215
"index.md",
1316
"constructors.md",

docs/src/derivatives.md

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
# Gradient and Hessian via ForwardDiff.jl
1+
# Gradient and Hessian computation
2+
3+
Experimental support for computing the gradient and the Hessian of the objective function (i.e. negative twice the profiled log likelihood) via ForwardDiff.jl and FiniteDiff.jl are provided as package extensions.
4+
5+
## via ForwardDiff.jl
26

3-
Experimental support for computing the gradient and the Hessian of the objective function (i.e. negative twice the profiled log likelihood) via ForwardDiff.jl are provided as a package extension.
47

58
The core functionality is provided by defining appropriate methods for `ForwardDiff.gradient` and `ForwardDiff.hessian`:
69

@@ -9,34 +12,52 @@ ForwardDiff.gradient(::LinearMixedModel{T}, ::Vector{T}) where {T}
912
ForwardDiff.hessian(::LinearMixedModel{T}, ::Vector{T}) where {T}
1013
```
1114

12-
## Exact zero at optimum for trivial models
15+
### Exact zero at optimum for trivial models
1316

14-
```@example ForwardDiff
17+
```@example Derivatives
1518
using MixedModels, ForwardDiff
1619
using DisplayAs # hide
1720
fm1 = lmm(@formula(yield ~ 1 + (1|batch)), MixedModels.dataset(:dyestuff2))
1821
DisplayAs.Text(ans) # hide
1922
```
2023

21-
```@example ForwardDiff
24+
```@example Derivatives
2225
ForwardDiff.gradient(fm1)
2326
```
2427

25-
```@example ForwardDiff
28+
```@example Derivatives
2629
ForwardDiff.hessian(fm1)
2730
```
2831

29-
## Approximate zero at optimum for non trivial models
32+
### Approximate zero at optimum for non trivial models
3033

31-
```@example ForwardDiff
34+
```@example Derivatives
3235
fm2 = lmm(@formula(reaction ~ 1 + days + (1+days|subj)), MixedModels.dataset(:sleepstudy))
3336
DisplayAs.Text(ans) # hide
3437
```
3538

36-
```@example ForwardDiff
39+
```@example Derivatives
3740
ForwardDiff.gradient(fm2)
3841
```
3942

40-
```@example ForwardDiff
43+
```@example Derivatives
4144
ForwardDiff.hessian(fm2)
4245
```
46+
47+
## via FiniteDiff.jl
48+
49+
The core functionality is provided by defining appropriate methods for `FiniteDiff.finite_difference_gradient` and `FiniteDiff.finite_difference_hessian`:
50+
51+
```@docs
52+
FiniteDiff.finite_difference_gradient(::LinearMixedModel{T}, ::Vector{T}) where {T}
53+
FiniteDiff.finite_difference_hessian(::LinearMixedModel{T}, ::Vector{T}) where {T}
54+
```
55+
56+
```@example Derivatives
57+
using FiniteDiff
58+
FiniteDiff.finite_difference_gradient(fm2)
59+
```
60+
61+
```@example Derivatives
62+
FiniteDiff.finite_difference_hessian(fm2)
63+
```

ext/MixedModelsFiniteDiffExt.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
module MixedModelsFiniteDiffExt
2+
3+
using MixedModels: LinearMixedModel, objective!, updateL!, setθ!
4+
using FiniteDiff: FiniteDiff, finite_difference_gradient, finite_difference_hessian
5+
6+
const FINITEDIFF = """
7+
!!! warning "FiniteDiff.jl support is experimental."
8+
Compatibility with FiniteDiff.jl is experimental. The precise structure,
9+
including function names and method definitions, is subject to
10+
change without being considered a breaking change. In particular,
11+
the exact set of parameters included is subject to change. The
12+
θ parameter is always included, but whether σ and/or the fixed effects
13+
should be included is currently still being decided.
14+
"""
15+
16+
"""
17+
FiniteDiff.finite_difference_gradient(model::LinearMixedModel, args...; kwargs...)
18+
19+
Evaluate the gradient of the objective function at the currently fitted parameter
20+
values.
21+
22+
$(FINITEDIFF)
23+
"""
24+
function FiniteDiff.finite_difference_gradient(
25+
model::LinearMixedModel{T}, θ::Vector{T}=model.θ, args...; kwargs...
26+
) where {T}
27+
local grad
28+
try
29+
grad = finite_difference_gradient(objective!(model), θ, args...; kwargs...)
30+
finally
31+
updateL!(setθ!(model, θ))
32+
end
33+
34+
return grad
35+
end
36+
37+
"""
38+
FiniteDiff.finite_difference_hessian(model::LinearMixedModel, args...; kwargs...)
39+
40+
Evaluate the Hessian of the objective function at the currently fitted parameter
41+
values.
42+
43+
$(FINITEDIFF)
44+
"""
45+
function FiniteDiff.finite_difference_hessian(
46+
model::LinearMixedModel{T}, θ::Vector{T}=model.θ, args...; kwargs...
47+
) where {T}
48+
local hess
49+
try
50+
hess = finite_difference_hessian(objective!(model), θ, args...; kwargs...)
51+
finally
52+
updateL!(setθ!(model, θ))
53+
end
54+
55+
return hess
56+
end
57+
58+
end # module

test/finitediff.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using MixedModels, FiniteDiff, Test
2+
include("modelcache.jl")
3+
4+
fm1 = only(models(:dyestuff2))
5+
@test FiniteDiff.finite_difference_gradient(fm1) [0.0]
6+
@test FiniteDiff.finite_difference_hessian(fm1) [28.7686] atol=0.0001
7+
8+
fm2 = last(models(:sleepstudy))
9+
@test FiniteDiff.finite_difference_gradient(fm2) [0.0, 0.0, 0.0] atol=0.005
10+
11+
# REML and zerocorr
12+
fm3 = lmm(@formula(reaction ~ 1 + days + zerocorr(1+days|subj)), MixedModels.dataset(:sleepstudy); REML=true)
13+
@test FiniteDiff.finite_difference_gradient(fm3) [0.0,0.0] atol=0.001
14+
15+
# crossed random effects
16+
if !Sys.iswindows() # this doesn't meet even the very loose tolerance on windows
17+
fm4 = last(models(:kb07))
18+
g = FiniteDiff.finite_difference_gradient(fm4)
19+
@test g zero(g) atol=0.1
20+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,4 @@ include("sigma.jl")
5050

5151
@testset "PRIMA" include("prima.jl")
5252
@testset "ForwardDiff" include("forwarddiff.jl")
53+
@testset "FiniteDiff" include("finitediff.jl")

0 commit comments

Comments
 (0)