Skip to content

Commit e24ff54

Browse files
committed
docu: tutorial on inspecting fit
1 parent a7959b9 commit e24ff54

File tree

10 files changed

+372
-33
lines changed

10 files changed

+372
-33
lines changed

dev/doubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true)
572572
trans_mP=StackedArray(transP, size(ζsP, 2))
573573
trans_mMs=StackedArray(transM, size(ζsMs, 1) * size(ζsMs, 3))
574574
θsP, θsMs = transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs)
575-
y = apply_process_model(θsP, θsMs, f, xP)
575+
y = apply_process_model(θsP, θsMs, f, xP)
576576
#(; y, θsP, θsMs) = HVI.apply_f_trans(ζsP, ζsMs, f_allsites, xP; transP, transM);
577577
(y_hmc, θsP_hmc, θsMs_hmc) = (; y, θsP, θsMs);
578578

docs/make.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@ makedocs(;
1717
"Problem" => "problem.md",
1818
"Tutorials" => [
1919
"Basic workflow" => "tutorials/basic_cpu.md",
20-
"Test quarto markdown" => "tutorials/test1.md",
20+
"Inspect results" => "tutorials/inspect_results.md",
21+
#"Test quarto markdown" => "tutorials/test1.md",
2122
],
2223
"How to" => [
2324
#".. model independent parameters" => "tutorials/how_to_guides/blocks_corr_site.md",
2425
#".. model site-global corr" => "tutorials/how_to_guides/corr_site_global.md",
2526
],
2627
"Explanation" => [
27-
"Theory" => "explanation/theory_hvi.md",
28+
#"Theory" => "explanation/theory_hvi.md", TODO activate when paper is published
2829
],
2930
"Reference" => [
3031
"Public" => "reference/reference_public.md",

docs/src/tutorials/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
[deps]
22
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
3+
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
34
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
45
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
56
DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
67
HybridVariationalInference = "a108c475-a4e2-4021-9a84-cfa7df242f64"
78
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
89
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
910
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
11+
PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
1012
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1113
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
1214
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

docs/src/tutorials/_pbm_matrix.qmd

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
```{julia}
2+
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
3+
# extract several covariates from xP
4+
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
5+
S1 = (CA.getdata(xPc[:S1,:])::ST)
6+
S2 = (CA.getdata(xPc[:S2,:])::ST)
7+
#
8+
# extract the parameters as row-repeated vectors
9+
n_obs = size(S1, 1)
10+
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
11+
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
12+
p1 = CA.getdata(θc[:, par]) ::VT
13+
repeat(p1', n_obs) # matrix: same for each concentration row in S1
14+
end
15+
#
16+
# each variable is a matrix (n_obs x n_site)
17+
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
18+
end
19+
```
20+

docs/src/tutorials/basic_cpu.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,9 @@ import Zygote
247247

248248
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
249249

250-
(; probo, interpreters) = solve(prob, solver; scenario, rng,
250+
(; probo, interpreters) = solve(prob, solver; rng,
251251
callback = callback_loss(100), # output during fitting
252252
epochs = 2,
253-
gdev = identity, # do not use GPU, here
254253
);
255254
```
256255

@@ -326,10 +325,9 @@ As a test of the new applicator, the results are refined by running a few more
326325
epochs of the optimization.
327326

328327
``` julia
329-
(; probo) = solve(probo_sites, solver; scenario, rng,
328+
(; probo) = solve(probo_sites, solver; rng,
330329
callback = callback_loss(100), # output during fitting
331330
epochs = 10,
332-
gdev = identity, # do not use GPU, here
333331
#is_inferred = Val(true), # activate type-checks
334332
);
335333
```

docs/src/tutorials/basic_cpu.qmd

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,9 @@ import Zygote
282282
283283
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
284284
285-
(; probo, interpreters) = solve(prob, solver; scenario, rng,
285+
(; probo, interpreters) = solve(prob, solver; rng,
286286
callback = callback_loss(100), # output during fitting
287287
epochs = 2,
288-
gdev = identity, # do not use GPU, here
289288
);
290289
```
291290

@@ -317,25 +316,7 @@ For the parameters, one row corresponds to
317316
one site. For the drivers and predictions, one column corresponds to one site.
318317

319318

320-
```{julia}
321-
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
322-
# extract several covariates from xP
323-
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
324-
S1 = (CA.getdata(xPc[:S1,:])::ST)
325-
S2 = (CA.getdata(xPc[:S2,:])::ST)
326-
#
327-
# extract the parameters as row-repeated vectors
328-
n_obs = size(S1, 1)
329-
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
330-
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
331-
p1 = CA.getdata(θc[:, par]) ::VT
332-
repeat(p1', n_obs) # matrix: same for each concentration row in S1
333-
end
334-
#
335-
# each variable is a matrix (n_obs x n_site)
336-
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
337-
end
338-
```
319+
{{< include _pbm_matrix.qmd >}}
339320

340321
Again, the function should not rely on the order of parameters but use symbolic indexing
341322
to extract the parameter vectors. For type stability of this symbolic indexing,
@@ -362,10 +343,9 @@ As a test of the new applicator, the results are refined by running a few more
362343
epochs of the optimization.
363344

364345
```{julia}
365-
(; probo) = solve(probo_sites, solver; scenario, rng,
346+
(; probo) = solve(probo_sites, solver; rng,
366347
callback = callback_loss(100), # output during fitting
367348
epochs = 10,
368-
gdev = identity, # do not use GPU, here
369349
#is_inferred = Val(true), # activate type-checks
370350
);
371351
```
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Inspect results of fitted problem
2+
3+
4+
``` @meta
5+
CurrentModule = HybridVariationalInference
6+
```
7+
8+
First load necessary packages.
9+
10+
``` julia
11+
using HybridVariationalInference
12+
using StableRNGs
13+
using ComponentArrays: ComponentArrays as CA
14+
using SimpleChains # for reloading the optimized problem
15+
using DistributionFits
16+
using JLD2
17+
using CairoMakie
18+
using PairPlots # scatterplot matrices
19+
```
20+
21+
After redefining the process-based model (currently JLD2 cannot save functions),
22+
the previously optimized Problem can be loaded.
23+
24+
``` julia
25+
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
26+
# extract several covariates from xP
27+
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
28+
S1 = (CA.getdata(xPc[:S1,:])::ST)
29+
S2 = (CA.getdata(xPc[:S2,:])::ST)
30+
#
31+
# extract the parameters as row-repeated vectors
32+
n_obs = size(S1, 1)
33+
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
34+
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
35+
p1 = CA.getdata(θc[:, par]) ::VT
36+
repeat(p1', n_obs) # matrix: same for each concentration row in S1
37+
end
38+
#
39+
# each variable is a matrix (n_obs x n_site)
40+
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
41+
end
42+
```
43+
44+
``` julia
45+
fname = "intermediate/basic_cpu_results.jld2"
46+
print(abspath(fname))
47+
probo, interpreters = load(fname, "probo", "interpreters");
48+
```
49+
50+
## Sample the posterior
51+
52+
A sample of both the posterior and the predictive posterior can be obtained
53+
using function [`sample_posterior`](@ref).
54+
55+
``` julia
56+
using StableRNGs
57+
rng = StableRNG(112)
58+
n_sample_pred = 400
59+
(; θsP, θsMs) = sample_posterior(rng, probo; n_sample_pred)
60+
```
61+
62+
Let's look at the results.
63+
64+
``` julia
65+
size(θsP), size(θsMs)
66+
```
67+
68+
((1, 400), (800, 2, 400))
69+
70+
The last dimension is the number of samples, the second-last dimension is
71+
the respective parameter. `θsMs` has an additional dimension denoting
72+
the site for which parameters are sampled.
73+
74+
They are ComponentArrays with the parameter dimension names that can be used:
75+
76+
``` julia
77+
θsMs[1,:r1,:] # sample of r1 of the first site
78+
```
79+
80+
### Corner plots
81+
82+
The relation between different variables can be well inspected by
83+
scatterplot matrices, also called corner plots or pair plots.
84+
`PairPlots.jl` provides a Makie-implementation of those.
85+
86+
Here, we plot the global parameters and the site-parameters for the first site.
87+
88+
``` julia
89+
i_site = 1
90+
θ1 = vcat(θsP, θsMs[i_site,:,:])
91+
θ1_nt = NamedTuple(k => CA.getdata(θ1[k,:]) for k in keys(θ1[:,1])) #
92+
plt = pairplot(θ1_nt)
93+
```
94+
95+
![](inspect_results_files/figure-commonmark/cell-9-output-1.png)
96+
97+
The plot shows that parameters for the first site, *K*₁ and *r*₁, are correlated,
98+
but that we did not model correlation with the global parameter, *K*₂.
99+
100+
Note that this plot shows only the first of the 800 sites.
101+
HVI estimated a 1602-dimensional posterior distribution including
102+
covariances among parameters.
103+
104+
### Expected values and marginal variances
105+
106+
Let's look at how the estimated uncertainty of a site parameter changes with
107+
its expected value.
108+
109+
``` julia
110+
par = :K1
111+
θmean = [mean(θsMs[s,par,:]) for s in axes(θsMs, 1)]
112+
θsd = [std(θsMs[s,par,:]) for s in axes(θsMs, 1)]
113+
fig = Figure(); ax = Axis(fig[1,1], xlabel="mean($par)",ylabel="sd($par)")
114+
scatter!(ax, θmean, θsd)
115+
fig
116+
```
117+
118+
![](inspect_results_files/figure-commonmark/cell-11-output-1.png)
119+
120+
We see that *K*₁ across sites ranges from about 0.18 to 0.25, and that
121+
its estimated uncertainty is about 0.034, slightly decreasing with the
122+
values of the parameter.
123+
124+
## Predictive Posterior
125+
126+
In addition to the uncertainty in parameters, we are also interested in
127+
the uncertainty of predictions, i.e. the predictive posterior.
128+
129+
We can either run the PBM for all the parameter samples that we obtained already,
130+
using [`apply_process_model`](@ref), or use [`predict_hvi`](@ref) which combines
131+
sampling the posterior and predictive posterior and returns the additional
132+
`NamedTuple` entry `y`.
133+
134+
``` julia
135+
(; y, θsP, θsMs) = predict_hvi(rng, probo; n_sample_pred)
136+
```
137+
138+
``` julia
139+
size(y)
140+
```
141+
142+
(8, 800, 400)
143+
144+
Again, the last dimension is the sample.
145+
The other dimensions correspond to the observations we provided for the fitting:
146+
The first dimension is the observation within one site, the second dimension is the site.
147+
148+
Let's look at how the uncertainty of the 4th observation scales with its
149+
predicted magnitude across sites.
150+
151+
``` julia
152+
i_obs = 4
153+
ymean = [mean(y[i_obs,s,:]) for s in axes(θsMs, 1)]
154+
ysd = [std(y[i_obs,s,:]) for s in axes(θsMs, 1)]
155+
fig = Figure(); ax = Axis(fig[1,1], xlabel="mean(y$i_obs)",ylabel="sd(y$i_obs)")
156+
scatter!(ax, ymean, ysd)
157+
fig
158+
```
159+
160+
![](inspect_results_files/figure-commonmark/cell-14-output-1.png)
161+
162+
We see that observed values for associated substrate concentrations range from about
163+
0.51 to 0.59 with an estimated standard deviation around 0.005 that decreases
164+
with the observed value.

0 commit comments

Comments
 (0)