Commit b943536

tutorail: save with PBM from DoubleMM
1 parent 2c4d338 commit b943536

10 files changed

+106
-53
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
@@ -14,4 +14,3 @@ docs/src/**/*_files/
1414
docs/src/**/*.html
1515
docs/src/**/*.ipynb
1616
docs/src/**/*Manifest.toml
17-
docs/src/tutorials/intermediate/*

docs/make.jl

Lines changed: 2 additions & 0 deletions
@@ -1,9 +1,11 @@
11
using HybridVariationalInference
2+
import HybridVariationalInference.DoubleMM
23
using Documenter
34

45
DocMeta.setdocmeta!(HybridVariationalInference, :DocTestSetup, :(using HybridVariationalInference); recursive=true)
56

67
makedocs(;
8+
#modules=[HybridVariationalInference, HybridVariationalInference.DoubleMM],
79
modules=[HybridVariationalInference],
810
authors="Thomas Wutzler <[email protected]> and contributors",
911
sitename="HybridVariationalInference.jl",

docs/src/reference/reference_public.md

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ i.e. the docstrings.
1111

1212
``` @autodocs
1313
Modules = [
14-
HybridVariationalInference,
14+
HybridVariationalInference, HybridVariationalInference.DoubleMM
1515
]
1616
Private = false
1717
```

docs/src/tutorials/basic_cpu.md

Lines changed: 16 additions & 7 deletions
@@ -96,7 +96,7 @@ transM = Stacked(HVI.Exp(), HVI.Exp())
9696

9797
Parameter transformations are specified using the `Bijectors` package.
9898
Because, `Bijectors.elementwise(exp)`, has problems with automatic differentiation (AD)
99-
on GPU, we use the non-exported [`Exp`]() wrapper inside `Bijectors.Stacked`.
99+
on GPU, we use the public but non-exported [`Exp`](@ref) wrapper inside `Bijectors.Stacked`.
100100

101101
### Prior information on parameters at constrained scale
102102

@@ -125,8 +125,8 @@ Here, we use synthetic data generated by the package.
125125

126126
``` julia
127127
rng = StableRNG(111)
128-
scenario = Val((:omit_r0, :covarK2, ))
129-
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario)
128+
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(
129+
rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
130130
```
131131

132132
Lets look at them.
@@ -335,16 +335,25 @@ epochs of the optimization.
335335
## Saving the results
336336

337337
Extracting useful information from the optimized HybridProblem is covered
338-
in the following tutorial. XXLink
339-
338+
in the following [Inspect results of fitted problem](@ref) tutorial.
340339
In order to use the results from this tutorial in other tutorials,
341340
the updated `probo` `HybridProblem` and the interpreters are saved to a JLD2 file.
342341

342+
Before the problem is updated to use the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
343+
of the PBM in module `DoubleMM` rather than
344+
module `Main` to allow for easier reloading with JLD2.
345+
346+
``` julia
347+
f_batch = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
348+
f_allsites = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
349+
probo2 = HybridProblem(probo; f_batch, f_allsites)
350+
```
351+
343352
``` julia
344353
using JLD2
345354
fname = "intermediate/basic_cpu_results.jld2"
346355
mkpath("intermediate")
347-
if probo isa AbstractHybridProblem # do not save on failure above
348-
jldsave(fname, false, IOStream; probo, interpreters)
356+
if probo2 isa AbstractHybridProblem # do not save on failure above
357+
jldsave(fname, false, IOStream; probo=probo2, interpreters)
349358
end
350359
```

docs/src/tutorials/basic_cpu.qmd

Lines changed: 20 additions & 11 deletions
@@ -31,7 +31,7 @@ using DistributionFits
3131

3232
Next, specify many moving parts of the Hybrid variational inference (HVI)
3333

34-
## The process-based model
34+
## The process-based model
3535
The example process based model (PBM) predicts a double-monod constrained rate
3636
for different substrate concentrations, `S1`, and `S2`.
3737

@@ -101,7 +101,7 @@ transM = Stacked(HVI.Exp(), HVI.Exp())
101101

102102
Parameter transformations are specified using the `Bijectors` package.
103103
Because, `Bijectors.elementwise(exp)`, has problems with automatic differentiation (AD)
104-
on GPU, we use the non-exported [`Exp`]() wrapper inside `Bijectors.Stacked`.
104+
on GPU, we use the public but non-exported [`Exp`](@ref) wrapper inside `Bijectors.Stacked`.
105105

106106
### Prior information on parameters at constrained scale
107107

@@ -130,16 +130,16 @@ Here, we use synthetic data generated by the package.
130130

131131
```{julia}
132132
rng = StableRNG(111)
133-
scenario = Val((:omit_r0, :covarK2, ))
134-
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario)
133+
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(
134+
rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
135135
```
136136

137137
```{julia}
138138
#| echo: false
139139
#| eval: false
140140
() -> begin
141141
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc) =
142-
gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario)
142+
gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
143143
end
144144
```
145145

@@ -265,7 +265,7 @@ y1 = f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
265265
#using Cthulhu
266266
#@descend_code_warntype f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))
267267
prob0 = HVI.DoubleMM.DoubleMMCase()
268-
f_batch0 = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = false)
268+
f_batch0 = get_hybridproblem_PBmodel(prob0; use_all_sites = false)
269269
y1f = f_batch0(θP, θMs, x_batch)[2]
270270
y1 .- y1f # equal
271271
end
@@ -352,17 +352,26 @@ epochs of the optimization.
352352

353353
## Saving the results
354354
Extracting useful information from the optimized HybridProblem is covered
355-
in the following tutorial. XXLink
356-
357-
In order to use the results from this tutorial in other tutorials,
355+
in the following [Inspect results of fitted problem](@ref) tutorial.
356+
In order to use the results from this tutorial in other tutorials,
358357
the updated `probo` `HybridProblem` and the interpreters are saved to a JLD2 file.
359358

359+
Before the problem is updated to use the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
360+
of the PBM in module `DoubleMM` rather than
361+
module `Main` to allow for easier reloading with JLD2.
362+
363+
```{julia}
364+
f_batch = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
365+
f_allsites = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
366+
probo2 = HybridProblem(probo; f_batch, f_allsites)
367+
```
368+
360369
```{julia}
361370
using JLD2
362371
fname = "intermediate/basic_cpu_results.jld2"
363372
mkpath("intermediate")
364-
if probo isa AbstractHybridProblem # do not save on failure above
365-
jldsave(fname, false, IOStream; probo, interpreters)
373+
if probo2 isa AbstractHybridProblem # do not save on failure above
374+
jldsave(fname, false, IOStream; probo=probo2, interpreters)
366375
end
367376
```
368377

docs/src/tutorials/inspect_results.md

Lines changed: 7 additions & 27 deletions
@@ -18,28 +18,8 @@ using CairoMakie
1818
using PairPlots # scatterplot matrices
1919
```
2020

21-
After redefinig the process-based model (currently JLD2 cannot save functions),
22-
the previously optimized Problem can be loaded.
23-
24-
``` julia
25-
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
26-
# extract several covariates from xP
27-
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
28-
S1 = (CA.getdata(xPc[:S1,:])::ST)
29-
S2 = (CA.getdata(xPc[:S2,:])::ST)
30-
#
31-
# extract the parameters as row-repeated vectors
32-
n_obs = size(S1, 1)
33-
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
34-
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
35-
p1 = CA.getdata(θc[:, par]) ::VT
36-
repeat(p1', n_obs) # matrix: same for each concentration row in S1
37-
end
38-
#
39-
# each variable is a matrix (n_obs x n_site)
40-
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
41-
end
42-
```
21+
This tutorial uses the fitted object saved in the
22+
[Basic workflow without GPU](@ref) tutorial.
4323

4424
``` julia
4525
fname = "intermediate/basic_cpu_results.jld2"
@@ -77,7 +57,7 @@ They are ComponentArrays with the parameter dimension names that can be used:
7757
θsMs[1,:r1,:] # sample of r1 of the first site
7858
```
7959

80-
### Corner plots
60+
## Corner plots
8161

8262
The relation between different variables can be well inspected by
8363
scatterplot matrices, also called corner plots or pair plots.
@@ -92,7 +72,7 @@ i_site = 1
9272
plt = pairplot(θ1_nt)
9373
```
9474

95-
![](inspect_results_files/figure-commonmark/cell-9-output-1.png)
75+
![](inspect_results_files/figure-commonmark/cell-8-output-1.png)
9676

9777
The plot shows that parameters for the first site, *K*₁ and *r*₁, are correlated,
9878
but that we did not model correlation with the global parameter, *K*₂.
@@ -101,7 +81,7 @@ Note that this plots shows only the first out of 800 sites.
10181
HVI estimated a 1602-dimensional posterior distribution including
10282
covariances among parameters.
10383

104-
### Expected values and marginal variances
84+
## Expected values and marginal variances
10585

10686
Lets look at how the estimated uncertainty of a site parameter changes with
10787
its expected value.
@@ -115,7 +95,7 @@ scatter!(ax, θmean, θsd)
11595
fig
11696
```
11797

118-
![](inspect_results_files/figure-commonmark/cell-11-output-1.png)
98+
![](inspect_results_files/figure-commonmark/cell-10-output-1.png)
11999

120100
We see that *K*₁ across sites ranges from about 0.18 to 0.25, and that
121101
its estimated uncertainty is about 0.034, slightly decreasing with the
@@ -157,7 +137,7 @@ scatter!(ax, ymean, ysd)
157137
fig
158138
```
159139

160-
![](inspect_results_files/figure-commonmark/cell-14-output-1.png)
140+
![](inspect_results_files/figure-commonmark/cell-13-output-1.png)
161141

162142
We see that observed values for associated substrate concentrations range about from
163143
0.51 to 0.59 with an estimated standard deviation around 0.005 that decreases

docs/src/tutorials/inspect_results.qmd

Lines changed: 4 additions & 4 deletions
@@ -29,10 +29,8 @@ using CairoMakie
2929
using PairPlots # scatterplot matrices
3030
```
3131

32-
After redefinig the process-based model (currently JLD2 cannot save functions),
33-
the previously optimized Problem can be loaded.
34-
35-
{{< include _pbm_matrix.qmd >}}
32+
This tutorial uses the fitted object saved at the end of the
33+
[Basic workflow without GPU](@ref) tutorial.
3634

3735
```{julia}
3836
fname = "intermediate/basic_cpu_results.jld2"
@@ -43,6 +41,8 @@ probo, interpreters = load(fname, "probo", "interpreters");
4341
```{julia}
4442
#| eval: false
4543
#| echo: false
44+
# not necessary any more with DoubleMM.f_doubleMM_sites
45+
# {{< include _pbm_matrix.qmd >}}
4646
# outside notebook, need to reset ModelApplicator, due to fθ defined in Notebook module
4747
#θFix = CA.ComponentVector{eltype(probo.θP)}(r0=0.3)
4848
θFix = CA.ComponentVector{eltype(probo.θP)}(
8 Bytes
Binary file not shown.

src/DoubleMM/DoubleMM.jl

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ using MLDataDevices
1414
import GPUArraysCore # used in conditional breakpoints
1515
import StableRNGs
1616

17-
export f_doubleMM, xP_S1, xP_S2
17+
export f_doubleMM, f_doubleMM_sites, xP_S1, xP_S2
1818
include("f_doubleMM.jl")
1919

2020

src/DoubleMM/f_doubleMM.jl

Lines changed: 55 additions & 1 deletion
@@ -21,6 +21,29 @@ int_xP1 = ComponentArrayInterpreter(CA.ComponentVector(S1 = xP_S1, S2 = xP_S2))
2121

2222
const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM)))
2323

24+
"""
25+
f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
26+
27+
Example process based model (PBM) predicts a double-monod constrained rate
28+
for different substrate concentration vectors, `x.S1`, and `x.S2` for a single site.
29+
θc is a ComponentVector with scalar parameters as components: `r0`, `r1`, `K1`, and `K2`
30+
31+
It predicts a rate for each entry in concentrations:
32+
`y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)`.
33+
34+
It is defined as
35+
```julia
36+
function f_doubleMM(θc::ComponentVector{ET}, x) where ET
37+
# extract parameters not depending on order, i.e whether they are in θP or θM
38+
# r0 = θc[:r0]
39+
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
40+
getdata(θc[par])::ET
41+
end
42+
y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
43+
return (y)
44+
end
45+
```
46+
"""
2447
function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
2548
# extract parameters not depending on order, i.e whether they are in θP or θM
2649
GPUArraysCore.allowscalar() do # index to scalar parameter in parameter vector
@@ -40,13 +63,44 @@ function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
4063
end
4164
end
4265

66+
"""
67+
f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
68+
69+
Example process based model (PBM) that predicts for a batch of sites.
70+
71+
Arguments
72+
- `θc`: parameters with one row per site and symbolic column index
73+
- `xPc`: model drivers with one column per site, and symbolic row index
74+
75+
Returns a matrix `(n_obs x n_site)` of predictions.
76+
77+
```julia
78+
function f_doubleMM_sites(θc::ComponentMatrix, xPc::ComponentMatrix)
79+
# extract several covariates from xP
80+
ST = typeof(getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
81+
S1 = (getdata(xPc[:S1,:])::ST)
82+
S2 = (getdata(xPc[:S2,:])::ST)
83+
#
84+
# extract the parameters as vectors that are row-repeated into a matrix
85+
n_obs = size(S1, 1)
86+
VT = typeof(getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
87+
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
88+
p1 = getdata(θc[:, par]) ::VT
89+
repeat(p1', n_obs) # matrix: same for each concentration row in S1
90+
end
91+
#
92+
# each variable is a matrix (n_obs x n_site)
93+
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
94+
end
95+
```
96+
"""
4397
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
4498
# extract several covariates from xP
4599
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
46100
S1 = (CA.getdata(xPc[:S1,:])::ST)
47101
S2 = (CA.getdata(xPc[:S2,:])::ST)
48102
#
49-
# extract the parameters as row-repeated vectors
103+
# extract the parameters as vectors that are row-repeated into a matrix
50104
n_obs = size(S1, 1)
51105
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
52106
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
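
Note on the row-repeat step in `f_doubleMM_sites`: the following is a minimal sketch with plain arrays rather than `ComponentArrays`, intended only to illustrate why repeating the per-site parameter vectors into `(n_obs x n_site)` matrices gives the same result as evaluating the double-Monod rate one site at a time. The helper names `double_monod` and `rep` and all numeric values are hypothetical and not part of the package.

```julia
# Double-Monod rate, written with broadcasting so that it accepts
# scalars, vectors, or matrices of matching shape.
double_monod(r0, r1, K1, K2, S1, S2) =
    r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)

# Hypothetical drivers: 3 observations (rows) x 2 sites (columns).
S1 = [1.0 2.0; 3.0 4.0; 5.0 6.0]
S2 = [0.5 0.5; 1.0 1.0; 2.0 2.0]

# Hypothetical parameters: one value per site.
r0 = [0.3, 0.3]
r1 = [0.5, 0.6]
K1 = [0.2, 0.25]
K2 = [2.0, 2.0]

# Row-repeat each parameter vector into an (n_obs x n_site) matrix,
# mirroring `repeat(p1', n_obs)` in the diff above.
n_obs = size(S1, 1)
rep(p) = repeat(p', n_obs)

# Prediction for all sites at once.
y_sites = double_monod(rep(r0), rep(r1), rep(K1), rep(K2), S1, S2)

# Prediction for the first site only, using scalar parameters.
y_site1 = double_monod(r0[1], r1[1], K1[1], K2[1], S1[:, 1], S2[:, 1])

# The first column of the batched result matches the single-site result.
@assert y_sites[:, 1] ≈ y_site1
```

In the package version the same computation is wrapped in type assertions (`::ST`, `::VT`) to work around non-type-stable `Symbol` indexing; the sketch omits these because plain arrays are indexed by integers.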
