Skip to content

Commit 739044b

Browse files
authored
Merge pull request #27 from EarthyScience/dev
add Howtos to docu
2 parents 8e6e08b + 2d7d7fe commit 739044b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+6284
-933
lines changed

_typos.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
[files]
2+
extend-exclude = ["docs/src_stash/"]
3+
14
[default.extend-words]
25
SOM = "SOM"
36
negLogLik = "negLogLik"

dev/doubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true)
572572
trans_mP=StackedArray(transP, size(ζsP, 2))
573573
trans_mMs=StackedArray(transM, size(ζsMs, 1) * size(ζsMs, 3))
574574
θsP, θsMs = transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs)
575-
y = apply_process_model(θsP, θsMs, f, xP)
575+
y = f(θsP, θsMs, f, xP)
576576
#(; y, θsP, θsMs) = HVI.apply_f_trans(ζsP, ζsMs, f_allsites, xP; transP, transM);
577577
(y_hmc, θsP_hmc, θsMs_hmc) = (; y, θsP, θsMs);
578578

docs/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/.quarto/

docs/_quarto.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
project:
2+
title: "HybridVariationalInference documentation"
3+
render:
4+
- src/tutorials/basic_cpu.qmd
5+
- src/tutorials/*.qmd
6+
7+
8+
9+

docs/make.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ makedocs(;
2323
#"Test quarto markdown" => "tutorials/test1.md",
2424
],
2525
"How to" => [
26-
#".. model independent parameters" => "tutorials/how_to_guides/blocks_corr_site.md",
27-
#".. model site-global corr" => "tutorials/how_to_guides/corr_site_global.md",
26+
".. model independent parameters" => "tutorials/blocks_corr.md",
27+
".. model site-global corr" => "tutorials/corr_site_global.md",
28+
".. use GPU" => "tutorials/lux_gpu.md",
2829
],
2930
"Explanation" => [
3031
#"Theory" => "explanation/theory_hvi.md", TODO activate when paper is published

docs/src/tutorials/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
[deps]
22
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
3+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
34
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
45
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
56
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
67
DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
78
HybridVariationalInference = "a108c475-a4e2-4021-9a84-cfa7df242f64"
89
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
10+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
911
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1012
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
1113
PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
@@ -14,3 +16,4 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
1416
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1517
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1618
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
19+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

docs/src/tutorials/basic_cpu.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ epochs of the optimization.
327327
``` julia
328328
(; probo) = solve(probo_sites, solver; rng,
329329
callback = callback_loss(100), # output during fitting
330-
epochs = 10,
330+
epochs = 20,
331331
#is_inferred = Val(true), # activate type-checks
332332
);
333333
```

docs/src/tutorials/basic_cpu.qmd

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ execute:
77
daemon: 3600
88
format:
99
commonmark:
10-
variant: -raw_html
10+
variant: -raw_html+tex_math_dollars
1111
wrap: preserve
1212
bibliography: twutz_txt.bib
1313
---
@@ -316,7 +316,25 @@ For the parameters, one row corresponds to
316316
one site. For the drivers and predictions, one column corresponds to one site.
317317

318318

319-
{{< include _pbm_matrix.qmd >}}
319+
```{julia}
320+
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
321+
# extract several covariates from xP
322+
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
323+
S1 = (CA.getdata(xPc[:S1,:])::ST)
324+
S2 = (CA.getdata(xPc[:S2,:])::ST)
325+
#
326+
# extract the parameters as row-repeated vectors
327+
n_obs = size(S1, 1)
328+
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
329+
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
330+
p1 = CA.getdata(θc[:, par]) ::VT
331+
repeat(p1', n_obs) # matrix: same for each concentration row in S1
332+
end
333+
#
334+
# each variable is a matrix (n_obs x n_site)
335+
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
336+
end
337+
```
320338

321339
Again, the function should not rely on the order of parameters but use symbolic indexing
322340
to extract the parameter vectors. For type stability of this symbolic indexing,
@@ -345,7 +363,7 @@ epochs of the optimization.
345363
```{julia}
346364
(; probo) = solve(probo_sites, solver; rng,
347365
callback = callback_loss(100), # output during fitting
348-
epochs = 10,
366+
epochs = 20,
349367
#is_inferred = Val(true), # activate type-checks
350368
);
351369
```

docs/src/tutorials/blocks_corr.md

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# How to model independent parameter-blocks in the posterior
2+
3+
4+
``` @meta
5+
CurrentModule = HybridVariationalInference
6+
```
7+
8+
This guide shows how to configure independent parameter-blocks in the correlations
9+
of the posterior.
10+
11+
## Motivation
12+
13+
Modelling all correlations among global and site PBM-parameters respectively
14+
requires many degrees of freedom.
15+
16+
To decrease the number of parameters to estimate, HVI allows decomposing the
17+
correlations into independent sub-blocks of parameters.
18+
19+
First load necessary packages.
20+
21+
``` julia
22+
using HybridVariationalInference
23+
using ComponentArrays: ComponentArrays as CA
24+
using Bijectors
25+
using SimpleChains
26+
using MLUtils
27+
using JLD2
28+
using Random
29+
using CairoMakie
30+
using PairPlots # scatterplot matrices
31+
```
32+
33+
This tutorial reuses and modifies the fitted object saved at the end of the
34+
[Basic workflow without GPU](@ref) tutorial.
35+
36+
``` julia
37+
fname = "intermediate/basic_cpu_results.jld2"
38+
print(abspath(fname))
39+
prob = probo_cor = load(fname, "probo");
40+
```
41+
42+
## Specifying blocks in correlation structure
43+
44+
HVI models the posterior of the parameters at unconstrained scale using a
45+
multivariate normal distribution. It estimates a parameterization of the
46+
associated blocks in the correlation matrix and requires a specification
47+
of the block-structure.
48+
49+
This is done by specifying the positions of the end of the blocks for
50+
the global (P) and the site-specific parameters (M) respectively using
51+
a `NamedTuple` of integer vectors.
52+
53+
The default specifies a single entry, meaning there is only one big
54+
block respectively, spanning all parameters.
55+
56+
``` julia
57+
cor_ends0 = (P=[length(prob.θP)], M=[length(prob.θM)])
58+
```
59+
60+
(P = [1], M = [2])
61+
62+
The following specification models one-entry blocks for each parameter
63+
in the correlation block of the site parameters, i.e. treating all parameters
64+
independently without modelling any correlations between them.
65+
66+
``` julia
67+
cor_ends = (P=[length(prob.θP)], M=1:length(prob.θM))
68+
```
69+
70+
(P = [1], M = 1:2)
71+
72+
## Reinitialize parameters for the posterior approximation.
73+
74+
HVI uses additional fitted parameters to represent the means and the
75+
covariance matrix of the posterior distribution of model parameters.
76+
With fewer correlations, also the number of those parameters changes,
77+
and those parameters must be reinitialized after changing the block structure in
78+
the correlation matrix.
79+
80+
Here, we construct initial estimates using [`init_hybrid_ϕunc`](@ref).
81+
82+
``` julia
83+
ϕunc = init_hybrid_ϕunc(cor_ends, zero(eltype(prob.θM)))
84+
```
85+
86+
In this two-site parameter case, the blocked structure saves only one degree of freedom:
87+
88+
``` julia
89+
length(ϕunc), length(probo_cor.ϕunc)
90+
```
91+
92+
(5, 6)
93+
94+
## Update the problem and redo the inversion
95+
96+
``` julia
97+
prob_ind = HybridProblem(prob; cor_ends, ϕunc)
98+
```
99+
100+
``` julia
101+
using OptimizationOptimisers
102+
import Zygote
103+
104+
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
105+
106+
(; probo) = solve(prob_ind, solver;
107+
callback = callback_loss(100), # output during fitting
108+
epochs = 20,
109+
); probo_ind = probo;
110+
```
111+
112+
## Compare the correlated vs. uncorrelated posterior
113+
114+
First, draw a sample.
115+
116+
``` julia
117+
n_sample_pred = 400
118+
(y_cor, θsP_cor, θsMs_cor) = (; y, θsP, θsMs) = predict_hvi(
119+
Random.default_rng(), probo_cor; n_sample_pred)
120+
(y_ind, θsP_ind, θsMs_ind) = (; y, θsP, θsMs) = predict_hvi(
121+
Random.default_rng(), probo_ind; n_sample_pred)
122+
```
123+
124+
``` julia
125+
i_site = 1
126+
θ1 = vcat(θsP_ind, θsMs_ind[i_site,:,:])
127+
θ1_nt = NamedTuple(k => CA.getdata(θ1[k,:]) for k in keys(θ1[:,1])) #
128+
plt = pairplot(θ1_nt)
129+
```
130+
131+
![](blocks_corr_files/figure-commonmark/cell-11-output-1.png)
132+
133+
The corner plot of the independent-parameters estimate shows
134+
no correlations between site parameters, *r*₁ and *K*₁.
135+
136+
``` julia
137+
i_out = 4
138+
fig = Figure(); ax = Axis(fig[1,1], xlabel="mean(y)",ylabel="sd(y)")
139+
ymean_cor = [mean(y_cor[i_out,s,:]) for s in axes(y_cor, 2)]
140+
ysd_cor = [std(y_cor[i_out,s,:]) for s in axes(y_cor, 2)]
141+
scatter!(ax, ymean_cor, ysd_cor, label="correlated")
142+
ymean_ind = [mean(y_ind[i_out,s,:]) for s in axes(y_ind, 2)]
143+
ysd_ind = [std(y_ind[i_out,s,:]) for s in axes(y_ind, 2)]
144+
scatter!(ax, ymean_ind, ysd_ind, label="independent")
145+
axislegend(ax, unique=true)
146+
fig
147+
```
148+
149+
![](blocks_corr_files/figure-commonmark/cell-12-output-1.png)
150+
151+
``` julia
152+
plot_sd_vs_mean = (par) -> begin
153+
fig = Figure(); ax = Axis(fig[1,1], xlabel="mean($par)",ylabel="sd($par)")
154+
θmean_cor = [mean(θsMs_cor[s,par,:]) for s in axes(θsMs_cor, 1)]
155+
θsd_cor = [std(θsMs_cor[s,par,:]) for s in axes(θsMs_cor, 1)]
156+
scatter!(ax, θmean_cor, θsd_cor, label="correlated")
157+
θmean_ind = [mean(θsMs_ind[s,par,:]) for s in axes(θsMs_ind, 1)]
158+
θsd_ind = [std(θsMs_ind[s,par,:]) for s in axes(θsMs_ind, 1)]
159+
scatter!(ax, θmean_ind, θsd_ind, label="independent")
160+
axislegend(ax, unique=true)
161+
fig
162+
end
163+
plot_sd_vs_mean(:K1)
164+
```
165+
166+
![](blocks_corr_files/figure-commonmark/cell-13-output-1.png)
167+
168+
The inversion that neglects correlations among site parameters results in
169+
the same magnitude of estimated uncertainty of predictions.
170+
However, the uncertainty of the model parameters is severely underestimated
171+
in this example.

0 commit comments

Comments
 (0)