Commit 900b5f3

Merge pull request #28 from EarthyScience/dev
Account for priors during inversion
2 parents 739044b + 7a34995 commit 900b5f3

39 files changed: +445 -212 lines

Project.toml

Lines changed: 4 additions & 0 deletions

@@ -12,11 +12,13 @@ CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -47,13 +49,15 @@ CommonSolve = "0.2.4"
 ComponentArrays = "0.15.19"
 DistributionFits = "0.3.9"
 Distributions = "0.25.117"
+FillArrays = "1.13.0"
 Flux = "0.14, 0.15, 0.16"
 Functors = "0.4, 0.5"
 GPUArraysCore = "0.1, 0.2"
 LinearAlgebra = "1.10"
 Lux = "1.4.2"
 MLDataDevices = "1.5, 1.6"
 MLUtils = "0.4.5"
+Missings = "1.2.0"
 Optimization = "3.19.3, 4"
 Random = "1.10.0"
 SimpleChains = "0.4"

_typos.toml

Lines changed: 1 addition & 0 deletions

@@ -4,3 +4,4 @@ extend-exclude = ["docs/src_stash/"]
 [default.extend-words]
 SOM = "SOM"
 negLogLik = "negLogLik"
+Missings = "Missings"

dev/doubleMM.jl

Lines changed: 7 additions & 7 deletions

@@ -32,7 +32,7 @@ cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity
 
 #------ setup synthetic data and training data loader
 prob0_ = HybridProblem(DoubleMM.DoubleMMCase(); scenario);
-(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
+(; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc
 ) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
 n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario)
 ζP_true, ζMs_true = log.(θP_true), log.(θMs_true)
@@ -59,7 +59,7 @@ n_epoch = 80
 maxiters = n_batches_in_epoch * n_epoch);
 # update the problem with optimized parameters
 prob0o = prob1o = probo;
-y_pred_global, y_pred, θMs = gf(prob0o; scenario, is_inferred=Val(true));
+y_pred, θMs = gf(prob0o; scenario, is_inferred=Val(true));
 # @descend_code_warntype gf(prob0o; scenario)
 #@usingany UnicodePlots
 plt = scatterplot(θMs_true'[:, 1], θMs[:, 1]);
@@ -77,7 +77,7 @@ histogram(vec(y_pred) - vec(y_true)) # predictions centered around y_o (or y_tru
 (; ϕ, resopt) = solve(prob0o, solver1; scenario, rng,
 callback = callback_loss(20), maxiters = 400)
 prob1o = HybridProblem(prob0o; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP)
-y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario)
+y_pred, θMs = gf(prob1o, xM, xP; scenario)
 scatterplot(θMs_true[1, :], θMs[1, :])
 scatterplot(θMs_true[2, :], θMs[2, :])
 prob1o.θP
@@ -91,7 +91,7 @@ end
 (; ϕ, resopt) = solve(prob2, solver1; scenario, rng,
 callback = callback_loss(20), maxiters = 600)
 prob2o = HybridProblem(prob2; ϕg = collect(ϕ.ϕg), θP = ϕ.θP)
-y_pred_global, y_pred, θMs = gf(prob2o, xM, xP)
+y_pred, θMs = gf(prob2o, xM, xP)
 prob2o.θP
 end
 
@@ -127,7 +127,7 @@ end
 (; ϕ, resopt) = solve(prob3, solver1; scenario, rng,
 callback = callback_loss(50), maxiters = 600)
 prob3o = HybridProblem(prob3; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP)
-y_pred_global, y_pred, θMs = gf(prob3o, xM, xP; scenario)
+y_pred, θMs = gf(prob3o, xM, xP; scenario)
 scatterplot(θMs_true[2, :], θMs[2, :])
 prob3o.θP
 scatterplot(vec(y_true), vec(y_pred))
@@ -173,7 +173,7 @@ solver_post = HybridPosteriorSolver(; alg = OptimizationOptimisers.Adam(0.01), n
 (y1, θsP1, θsMs1) = (y, θsP, θsMs);
 
 () -> begin # prediction with fitted parameters (should be smaller than mean)
-y_pred_global, y_pred2, θMs = gf(prob1o, xM, xP; scenario)
+y_pred2, θMs = gf(prob1o, xM, xP; scenario)
 scatterplot(θMs_true[1, :], θMs[1, :])
 scatterplot(θMs_true[2, :], θMs[2, :])
 hcat(θP_true, θP) # all parameters overestimated
@@ -366,7 +366,7 @@ end
 # ζMs = invt.transM.(θMs_i)
 # _f = get_hybridproblem_PBmodel(probo; scenario)
 # y_site = map(eachcol(θPs), θMs_i) do θP, θM
-# y_global, y = _f(θP, reshape(θM, (length(θM), 1)), xP[[i_site]])
+# y = _f(θP, reshape(θM, (length(θM), 1)), xP[[i_site]])
 # y[:,1]
 # end |> stack
 nLs = get_hybridproblem_neg_logden_obs(

docs/make.jl

Lines changed: 1 addition & 1 deletion

@@ -23,9 +23,9 @@ makedocs(;
 #"Test quarto markdown" => "tutorials/test1.md",
 ],
 "How to" => [
+".. use GPU" => "tutorials/lux_gpu.md",
 ".. model independent parameters" => "tutorials/blocks_corr.md",
 ".. model site-global corr" => "tutorials/corr_site_global.md",
-".. use GPU" => "tutorials/lux_gpu.md",
 ],
 "Explanation" => [
 #"Theory" => "explanation/theory_hvi.md", TODO activate when paper is published

docs/src/tutorials/Project.toml

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@ DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
 HybridVariationalInference = "a108c475-a4e2-4021-9a84-cfa7df242f64"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
 PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"

docs/src/tutorials/basic_cpu.md

Lines changed: 3 additions & 3 deletions

@@ -104,14 +104,14 @@ HVI is an approximate bayesian analysis and combines prior information on
 the parameters with the model and observed data.
 
 Here, we provide a wide prior by fitting a Lognormal distributions to
-- the mean corresponding to the initial value provided above
-- the 0.95-quantile 3 times the mean
+- the mode corresponding to the initial value provided above
+- the 0.95-quantile 3 times the mode
 using the `DistributionFits.jl` package.
 
 ``` julia
 θall = vcat(θP, θM)
 priors_dict = Dict{Symbol, Distribution}(
-keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
+keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode)))
 ```
 
 ## Observations, model drivers and covariates
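
The hunk above changes the tutorial's prior specification from anchoring the LogNormal at its mean to anchoring it at its mode. A minimal sketch of what the updated call computes, assuming a single hypothetical parameter with initial value 0.5 standing in for one entry of `θall`:

``` julia
using Distributions, DistributionFits

θ0 = 0.5  # hypothetical initial parameter value

# Fit a LogNormal whose mode equals θ0 and whose 0.95-quantile is 3 * θ0,
# mirroring the tutorial's fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode))
prior = fit(LogNormal, θ0, QuantilePoint(3 * θ0, 0.95), Val(:mode))

mode(prior)            # ≈ 0.5
quantile(prior, 0.95)  # ≈ 1.5
```

Anchoring at the mode keeps the prior's most likely value at the initial estimate; with a mean-anchored fit the mode would fall below it, since a LogNormal's mode is always smaller than its mean.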

docs/src/tutorials/basic_cpu.qmd

Lines changed: 4 additions & 4 deletions

@@ -109,14 +109,14 @@ HVI is an approximate bayesian analysis and combines prior information on
 the parameters with the model and observed data.
 
 Here, we provide a wide prior by fitting a Lognormal distributions to
-- the mean corresponding to the initial value provided above
-- the 0.95-quantile 3 times the mean
+- the mode corresponding to the initial value provided above
+- the 0.95-quantile 3 times the mode
 using the `DistributionFits.jl` package.
 
 ```{julia}
 θall = vcat(θP, θM)
 priors_dict = Dict{Symbol, Distribution}(
-keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
+keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode)))
 ```
 
 ## Observations, model drivers and covariates
@@ -138,7 +138,7 @@ rng = StableRNG(111)
 #| echo: false
 #| eval: false
 () -> begin
-(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc) =
+(; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc) =
 gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
 end
 ```
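
For reference, the two conditions stated in both tutorial versions pin down the LogNormal parameters directly. Assuming the standard LogNormal(μ, σ) parameterisation and an initial value θ₀:

```math
\mathrm{mode} = e^{\mu - \sigma^2} = \theta_0, \qquad
q_{0.95} = e^{\mu + z_{0.95}\sigma} = 3\,\theta_0, \qquad z_{0.95} \approx 1.645 .
```

Dividing the quantile condition by the mode condition gives `exp(σ² + z_{0.95} σ) = 3`, a quadratic in σ with one positive root, after which `μ = log(θ₀) + σ²`.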
3 binary files changed (not rendered): -6.08 KB, 981 Bytes, -7.25 KB
