Skip to content

Commit f5d7a64

Browse files
Implement a DataUpdateCallback (#287)
* Implement a data update likelihood callback * Make the likelihood accumulator optional * Make some DataUpdateCallback args into kwargs * WIP * Fix the partial observability implementation * Better type signatures * Misc * Add Fenrir to ProbNumDiffEq.jl * Add proper data likelihood tests * Split the `data_likelihoods.jl` file into separate files per method * Fix a bug * Remove the smooth=false suggestion if dense=false * Fix another broken test (again?) * Revert the ManifoldUpdate renaming * Actually remove the filtering likelihood from the dalton file * Try out DocStringExtensions.jl * Add compat entry to DocStringExtensions * Create a DataLikelihoods submodule * Remove Fenrir from the docs * Write docstrings for the fenrir and dalton likelihood and doc them * Make marginalize a bit more flexible * Update the Probabilistic Exponential Integrator citation * Remove the underscores from the likelihoods again * Add a parameter inference tutorial again * Update the doc index * Make the parameter inference example even nicer * Improve tests * Make Fenrir compatible with matrix-valued observation noise * Test matrix-valued observation noise * JuliaFormatter.jl * Add support for PSDMatrices as observation noise * Fix a MarkovKernel docstring * Implement and test update equations with non-zero observation noise * Shorten and streamline the `update!` functionality a bit * Simplify Fenrir quite a bit * Misc updates to the data update callback * Make the Fenrir code yet a bit more compact * JuliaFormatter.jl * Change the parameter inference example to partial observations * Add DiffEqCallbacks compat entry * Make the parameter inference doc code a bit nicer * Add docstrings for DataUpdateLogLikelihood and DataUpdateCallback * Polish the docs * Make sure that Fenrir ll is not inf, even if it would technically be * JuliaFormatter.jl * Faster DataUpdateCallback * Slight fenrir speed improvement * Remove some changes that I somehow introduced earlier * Add the data likelihood tests to runtests.jl * Fix the failing data likelihood tests * JuliaFormatter.jl * Test for more data likelihood observation noise types
1 parent 3bb4739 commit f5d7a64

25 files changed

+701
-114
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ version = "0.14.0"
66
[deps]
77
ArrayAllocators = "c9d4266f-a5cb-439d-837c-c97b191379f5"
88
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
9+
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
910
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
11+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1012
ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
1113
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1214
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
@@ -47,7 +49,9 @@ RecipesBaseExt = "RecipesBase"
4749
[compat]
4850
ArrayAllocators = "0.3"
4951
DiffEqBase = "6.122"
52+
DiffEqCallbacks = "2.36"
5053
DiffEqDevTools = "2"
54+
DocStringExtensions = "0.9"
5155
ExponentialUtilities = "1"
5256
FastBroadcast = "0.2"
5357
FastGaussQuadrature = "0.5, 1"

docs/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ Bibliography = "f1be7e48-bf82-45af-a471-ae754a193061"
33
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
44
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
55
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
6-
Fenrir = "e9b4b195-f5cd-427c-8076-5358c553c37f"
76
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
87
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
98
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

docs/make.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,19 @@ makedocs(
3434
pages=[
3535
"Home" => "index.md",
3636
"Tutorials" => [
37-
"Getting Started" => "tutorials/getting_started.md"
38-
"Second Order ODEs and Energy Preservation" => "tutorials/dynamical_odes.md"
39-
"Differential Algebraic Equations" => "tutorials/dae.md"
40-
"Probabilistic Exponential Integrators" => "tutorials/exponential_integrators.md"
41-
"Parameter Inference" => "tutorials/fenrir.md"
37+
"Getting Started" => "tutorials/getting_started.md",
38+
"Second Order ODEs and Energy Preservation" => "tutorials/dynamical_odes.md",
39+
"Differential Algebraic Equations" => "tutorials/dae.md",
40+
"Probabilistic Exponential Integrators" => "tutorials/exponential_integrators.md",
41+
"Parameter Inference" => "tutorials/ode_parameter_inference.md",
4242
],
4343
"Solvers and Options" => [
4444
"solvers.md",
4545
"priors.md",
4646
"initialization.md",
4747
"diffusions.md",
4848
],
49+
"Data Likelihoods" => "likelihoods.md",
4950
"Benchmarks" => [
5051
"Multi-Language Wrapper Benchmark" => "benchmarks/multi-language-wrappers.md",
5152
"Non-stiff ODEs" => [
@@ -64,8 +65,8 @@ makedocs(
6465
],
6566
],
6667
"Internals" => [
67-
"Filtering and Smoothing" => "filtering.md"
68-
"Implementation via OrdinaryDiffEq.jl" => "implementation.md"
68+
"Filtering and Smoothing" => "filtering.md",
69+
"Implementation via OrdinaryDiffEq.jl" => "implementation.md",
6970
],
7071
"References" => "references.md",
7172
],

docs/src/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Run Julia, enter `]` to bring up Julia's package manager, and add the ProbNumDif
1414

1515
```
1616
julia> ]
17-
(v1.9) pkg> add ProbNumDiffEq
17+
(v1.10) pkg> add ProbNumDiffEq
1818
```
1919

2020
## Getting Started
@@ -35,10 +35,10 @@ For a quick introduction check out the "[Solving ODEs with Probabilistic Numeric
3535
- Arbitrary precision via Julia's built-in [arbitrary precision arithmetic](https://docs.julialang.org/en/v1/manual/integers-and-floating-point-numbers/#Arbitrary-Precision-Arithmetic)
3636
- Specialized solvers for second-order ODEs (see [Second Order ODEs and Energy Preservation](@ref))
3737
- Compatible with DAEs in mass-matrix ODE form (see [Solving DAEs with Probabilistic Numerics](@ref))
38+
- Data likelihoods for parameter-inference in ODEs (see [Parameter Inference with ProbNumDiffEq.jl](@ref))
3839

3940

4041
## Related packages
4142

4243
- [probdiffeq](https://pnkraemer.github.io/probdiffeq/): Fast and feature-rich filtering-based probabilistic ODE solvers in JAX.
4344
- [ProbNum](https://probnum.readthedocs.io/en/latest/): Probabilistic numerics in Python. It has not only probabilistic ODE solvers, but also probabilistic linear solvers, Bayesian quadrature, and many filtering and smoothing implementations.
44-
- [Fenrir.jl](https://github.com/nathanaelbosch/Fenrir.jl): Parameter-inference in ODEs with probabilistic ODE solvers. This package builds on ProbNumDiffEq.jl to provide a negative marginal log-likelihood function, which can then be used with an optimizer or with MCMC for parameter inference.

docs/src/likelihoods.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Data Likelihoods
2+
3+
4+
```@docs
5+
ProbNumDiffEq.DataLikelihoods.fenrir_data_loglik
6+
ProbNumDiffEq.DataLikelihoods.dalton_data_loglik
7+
```

docs/src/refs.bib

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,4 +180,16 @@ @book{sarkka19appliedsde
180180
author = {Särkkä, Simo and Solin, Arno},
181181
year = 2019,
182182
collection = {Institute of Mathematical Statistics Textbooks}
183-
}
183+
}
184+
185+
@article{wu23dalton,
186+
author = {Mohan Wu and Martin Lysy},
187+
title = {Data-Adaptive Probabilistic Likelihood Approximation for
188+
Ordinary Differential Equations},
189+
journal = {CoRR},
190+
year = 2023,
191+
url = {http://arxiv.org/abs/2306.05566},
192+
archivePrefix ={arXiv},
193+
eprint = {2306.05566},
194+
primaryClass = {stat.ML}
195+
}
Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,7 @@
1-
# Parameter Inference with ProbNumDiffEq.jl and Fenrir.jl
1+
# Parameter Inference with ProbNumDiffEq.jl
22

3-
!!! note
4-
This is mostly just a copy from [the tutorial included in the Fenrir.jl documentation](https://nathanaelbosch.github.io/Fenrir.jl/stable/gettingstarted/), so have a look there too!
53

64

7-
```@example fenrir
8-
using LinearAlgebra
9-
using OrdinaryDiffEq, ProbNumDiffEq, Plots
10-
using Fenrir
11-
using Optimization, OptimizationOptimJL
12-
stack(x) = copy(reduce(hcat, x)') # convenient
13-
nothing # hide
14-
```
15-
16-
## The parameter inference problem in general
175
Let's assume we have an initial value problem (IVP)
186
```math
197
\begin{aligned}
@@ -29,12 +17,15 @@ u(t_n) = H y(t_n) + v_n, \qquad v_n \sim \mathcal{N}(0, R).
2917
The question of interest is: How can we compute the marginal likelihood ``p(\mathcal{D} \mid \theta)``?
3018
Short answer: We can't. It's intractable, because computing the true IVP solution exactly ``y(t)`` is intractable.
3119
What we can do however is compute an approximate marginal likelihood.
32-
This is what Fenrir.jl provides.
33-
For details, check out the [paper](https://proceedings.mlr.press/v162/tronarp22a.html).
20+
This is what `ProbNumDiffEq.DataLikelihoods` provides.
3421

3522
## The specific problem, in code
3623
Let's assume that the true underlying dynamics are given by a FitzHugh-Nagumo model
37-
```@example fenrir
24+
25+
```@example parameterinference
26+
using ProbNumDiffEq, LinearAlgebra, OrdinaryDiffEq, Plots
27+
Plots.theme(:default; markersize=2, markerstrokewidth=0.1)
28+
3829
function f(du, u, p, t)
3930
a, b, c = p
4031
du[1] = c*(u[1] - u[1]^3/3 + u[2])
@@ -46,30 +37,43 @@ p = (0.2, 0.2, 3.0)
4637
true_prob = ODEProblem(f, u0, tspan, p)
4738
```
4839
from which we generate some artificial noisy data
49-
```@example fenrir
40+
```@example parameterinference
5041
true_sol = solve(true_prob, Vern9(), abstol=1e-10, reltol=1e-10)
5142
5243
times = 1:0.5:20
53-
observation_noise_var = 1e-1
54-
odedata = [true_sol(t) .+ sqrt(observation_noise_var) * randn(length(u0)) for t in times]
44+
σ = 1e-1
45+
H = [1 0;]
46+
odedata = [H*true_sol(t) .+ σ * randn() for t in times]
5547
5648
plot(true_sol, color=:black, linestyle=:dash, label=["True Solution" ""])
57-
scatter!(times, stack(odedata), markersize=2, markerstrokewidth=0.1, color=1, label=["Noisy Data" ""])
49+
scatter!(times, stack(odedata)', color=1, label=["Noisy Data" ""])
5850
```
5951
Our goal is then to recover the true parameter `p` (and thus also the true trajectory plotted above) the noisy data.
6052

6153
## Computing the negative log-likelihood
62-
To do parameter inference - be it maximum-likelihod, maximum a posteriori, or full Bayesian inference with MCMC - we need to evaluate the likelihood of given a parameter estimate ``\theta_\text{est}``.
63-
This is exactly what Fenrir.jl's [`fenrir_nll`](https://nathanaelbosch.github.io/Fenrir.jl/stable/#Fenrir.fenrir_nll) provides:
64-
```@example fenrir
65-
p_est = (0.1, 0.1, 2.0)
66-
prob = remake(true_prob, p=p_est)
54+
To do parameter inference - be it maximum-likelihod, maximum a posteriori, or full Bayesian inference with MCMC - we need to evaluate the likelihood of given a parameter estimate ``\theta_\text{est}``, which corresponds to the probability of the data under the trajectory returned by the ODE solver
55+
```@example parameterinference
56+
θ_est = (0.1, 0.1, 2.0)
57+
prob = remake(true_prob, p=θ_est)
58+
plot(true_sol, color=:black, linestyle=:dash, label=["True Solution" ""])
59+
scatter!(times, stack(odedata)', color=1, label=["Noisy Data" ""])
60+
sol = solve(prob, EK1(), adaptive=false, dt=1e-1)
61+
plot!(sol, color=2, label=["Numerical solution for θ_est" ""])
62+
```
63+
This quantity can be computed in multiple ways; see
64+
[Data Likelihoods](@ref).
65+
Here we use
66+
[`ProbNumDiffEq.DataLikelihoods.fenrir_data_loglik`](@ref):
67+
```@example parameterinference
68+
using ProbNumDiffEq.DataLikelihoods
69+
6770
data = (t=times, u=odedata)
68-
κ² = 1e10
69-
nll, _, _ = fenrir_nll(prob, data, observation_noise_var, κ²; dt=1e-1)
70-
nll
71+
nll = -fenrir_data_loglik(
72+
prob, EK1(smooth=true);
73+
data, observation_noise_cov=σ^2, observation_matrix=H,
74+
adaptive=false, dt=1e-1)
7175
```
72-
This is the negative marginal log-likelihood of the parameter `p_est`.
76+
This is the negative marginal log-likelihood of the parameter `θ_est`.
7377
You can use it as any other NLL: Optimize it to compute maximum-likelihood estimates or MAPs, or plug it into MCMC to sample from the posterior.
7478
In our paper [tronarp22fenrir](@cite) we compute MLEs by pairing Fenrir with [Optimization.jl](http://optimization.sciml.ai/stable/) and [ForwardDiff.jl](https://juliadiff.org/ForwardDiff.jl/stable/).
7579
Let's quickly explore how to do this next.
@@ -80,23 +84,29 @@ Let's quickly explore how to do this next.
8084
To compute a maximum-likelihood estimate (MLE), we just need to maximize ``\theta \to p(\mathcal{D} \mid \theta)`` - that is, minimize the `nll` from above.
8185
We use [Optimization.jl](https://docs.sciml.ai/Optimization/stable/) for this.
8286
First, define a loss function and create an `OptimizationProblem`
83-
```@example fenrir
87+
```@example parameterinference
88+
using Optimization, OptimizationOptimJL
89+
8490
function loss(x, _)
8591
ode_params = x[begin:end-1]
8692
prob = remake(true_prob, p=ode_params)
87-
κ² = exp(x[end]) # the diffusion parameter of the EK1
88-
return fenrir_nll(prob, data, observation_noise_var, κ²; dt=1e-1)
93+
κ² = exp(x[end]) # we also optimize the diffusion parameter of the EK1
94+
return -fenrir_data_loglik(
95+
prob, EK1(smooth=true, diffusionmodel=FixedDiffusion(κ², false));
96+
data, observation_noise_cov=σ^2, observation_matrix=H,
97+
adaptive=false, dt=1e-1
98+
)
8999
end
90100
91101
fun = OptimizationFunction(loss, Optimization.AutoForwardDiff())
92102
optprob = OptimizationProblem(
93-
fun, [p_est..., 1e0];
103+
fun, [θ_est..., 1e0];
94104
lb=[0.0, 0.0, 0.0, -10], ub=[1.0, 1.0, 5.0, 20] # lower and upper bounds
95105
)
96106
```
97107

98108
Then, just `solve` it! Here we use LBFGS:
99-
```@example fenrir
109+
```@example parameterinference
100110
optsol = solve(optprob, LBFGS())
101111
p_mle = optsol.u[1:3]
102112
p_mle # hide
@@ -105,21 +115,27 @@ p_mle # hide
105115
Success! The computed MLE is quite close to the true parameter which we used to generate the data.
106116
As a final step, let's plot the true solution, the data, and the result of the MLE:
107117

108-
```@example fenrir
118+
```@example parameterinference
109119
plot(true_sol, color=:black, linestyle=:dash, label=["True Solution" ""])
110-
scatter!(times, stack(odedata), markersize=2, markerstrokewidth=0.1, color=1, label=["Noisy Data" ""])
120+
scatter!(times, stack(odedata)', color=1, label=["Noisy Data" ""])
111121
mle_sol = solve(remake(true_prob, p=p_mle), EK1())
112122
plot!(mle_sol, color=3, label=["MLE-parameter Solution" ""])
113123
```
114124

115125
Looks good!
116126

117127

118-
### Reference
128+
## API Documentation
129+
130+
For more details, see the API documentation of `ProbNumDiffEq.DataLikelihoods` at [Data Likelihoods](@ref).
131+
132+
133+
### References
119134

120135
```@bibliography
121136
Pages = []
122137
Canonical = false
123138
124139
tronarp22fenrir
140+
wu23dalton
125141
```

src/ProbNumDiffEq.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using LinearAlgebra
88
import LinearAlgebra: mul!
99
import Statistics: mean, var, std, cov
1010
using Printf
11+
using DocStringExtensions
1112

1213
using Reexport
1314
@reexport using DiffEqBase
@@ -31,6 +32,7 @@ using ArrayAllocators
3132
using FiniteHorizonGramians
3233
using FillArrays
3334
using MatrixEquations
35+
using DiffEqCallbacks
3436

3537
@reexport using GaussianDistributions
3638

@@ -95,8 +97,18 @@ if !isdefined(Base, :get_extension)
9597
include("../ext/DiffEqDevToolsExt.jl")
9698
end
9799

98-
include("callbacks.jl")
100+
include("callbacks/manifoldupdate.jl")
99101
export ManifoldUpdate
102+
include("callbacks/dataupdate.jl")
103+
export DataUpdateLogLikelihood, DataUpdateCallback
104+
105+
include("data_likelihoods/dalton.jl")
106+
include("data_likelihoods/filtering.jl")
107+
include("data_likelihoods/fenrir.jl")
108+
module DataLikelihoods
109+
import ..ProbNumDiffEq: dalton_data_loglik, filtering_data_loglik, fenrir_data_loglik
110+
export dalton_data_loglik, filtering_data_loglik, fenrir_data_loglik
111+
end
100112

101113
include("precompile.jl")
102114

src/algorithms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ julia> solve(prob, RosenbrockExpEK())
187187
```
188188
189189
# Reference
190-
* [bosch23expint](@cite) Bosch et al, "Probabilistic Exponential Integrators", arXiv (2021)
190+
* [bosch23expint](@cite) Bosch et al, "Probabilistic Exponential Integrators", NeurIPS (2023)
191191
"""
192192
RosenbrockExpEK(; order=3, kwargs...) =
193193
EK1(; prior=IOUP(order, update_rate_parameter=true), kwargs...)

src/caches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ function OrdinaryDiffEq.alg_cache(
168168
copy!(x0.Σ, apply_diffusion(x0.Σ, initdiff))
169169

170170
# Measurement model related things
171-
R = factorized_similar(FAC, d, d)
171+
R = nothing # factorized_similar(FAC, d, d)
172172
H = factorized_similar(FAC, d, D)
173173
v = similar(Array{uElType}, d)
174174
S = PSDMatrix(factorized_zeros(FAC, D, d))

0 commit comments

Comments
 (0)