Skip to content

Commit 0d91f84

Browse files
authored
Merge pull request #17 from EarthyScience/dev
allow specifying initial uncertainties and handle empty correlation parameter vector
2 parents 750ab58 + 81c7f44 commit 0d91f84

14 files changed

+317
-177
lines changed

dev/doubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ scatterplot(θMs_true[2,:], θMs[2,:])
5555
prob1o.θP
5656
scatterplot(vec(y_true), vec(y_pred))
5757

58-
# still overestimating θMs
58+
# still overestimating θMs and θP
5959

6060
() -> begin # with more iterations?
6161
prob2 = prob1o

src/AbstractHybridProblem.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
"""
22
Type to dispatch constructing data and network structures
3-
for different cases of hybrid problem setups
3+
for different cases of hybrid problem setups.
44
55
For a specific prob, provide functions that specify details
66
- `get_hybridproblem_MLapplicator`
7+
- `get_hybridproblem_transforms`
78
- `get_hybridproblem_PBmodel`
89
- `get_hybridproblem_neg_logden_obs`
910
- `get_hybridproblem_par_templates`
10-
- `get_hybridproblem_transforms`
11+
- `get_hybridproblem_ϕunc`
1112
- `get_hybridproblem_train_dataloader` (default depends on `gen_hybridcase_synthetic`)
1213
optionally
1314
- `gen_hybridcase_synthetic`
1415
- `get_hybridproblem_n_covar` (defaults to number of rows in xM in train_dataloader )
1516
- `get_hybridproblem_float_type` (defaults to `eltype(θM)`)
16-
- `get_hybridproblem_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`)
17+
- `get_hybridproblem_cor_ends` (defaults to include all correlations: `(P=(1,), M=(1,))`)
18+
19+
The initial value of parameters to estimate is spread
20+
- `ϕg`: parameter of the MLapplicator: returned by `get_hybridproblem_MLapplicator`
21+
- `ζP`: mean of the PBmodel parameters: returned by `get_hybridproblem_par_templates`
22+
- `ϕunc`: additional parameters of the approximte posterior: returned by `get_hybridproblem_ϕunc`
1723
"""
1824
abstract type AbstractHybridProblem end;
1925

@@ -64,6 +70,13 @@ Provide tuple of templates of ComponentVectors `θP` and `θM`.
6470
"""
6571
function get_hybridproblem_par_templates end
6672

73+
"""
74+
get_hybridproblem_ϕunc(::AbstractHybridProblem; scenario)
75+
76+
Provide a ComponentArray of the initial additional parameters of the approximate posterior.
77+
"""
78+
function get_hybridproblem_ϕunc end
79+
6780
"""
6881
get_hybridproblem_transforms(::AbstractHybridProblem; scenario)
6982
@@ -143,7 +156,7 @@ function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenari
143156
end
144157

145158
"""
146-
get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario)
159+
get_hybridproblem_cor_ends(prob::AbstractHybridProblem; scenario)
147160
148161
Specify blocks in correlation matrices among parameters.
149162
Returns a NamedTuple.
@@ -159,6 +172,7 @@ then the first subrange starts at position 1 and the second subrange starts at p
159172
If there is only single block of all ML-predicted parameters being correlated
160173
with each other then this block starts at position 1: `(P=(1,3), M=(1,))`.
161174
"""
162-
function get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario = ())
163-
(P = (1,), M = (1,))
175+
function get_hybridproblem_cor_ends(prob::AbstractHybridProblem; scenario = ())
176+
pt = get_hybridproblem_par_templates(prob; scenario)
177+
(P = [length(pt.θP)], M = [length(pt.θM)])
164178
end

src/HybridProblem.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ struct HybridProblem <: AbstractHybridProblem
77
py
88
transP
99
transM
10-
cor_starts # = (P=(1,),M=(1,))
10+
cor_ends # = (P=(1,),M=(1,))
1111
get_train_loader
1212
# inner constructor to constrain the types
1313
function HybridProblem(
@@ -20,8 +20,8 @@ struct HybridProblem <: AbstractHybridProblem
2020
#train_loader::DataLoader,
2121
# return a function that constructs the trainloader based on n_batch
2222
get_train_loader::Function,
23-
cor_starts::NamedTuple = (P = (1,), M = (1,)))
24-
new(θP, θM, f, g, ϕg, py, transM, transP, cor_starts, get_train_loader)
23+
cor_ends::NamedTuple = (P = [length(θP)], M = [length(θM)]))
24+
new(θP, θM, f, g, ϕg, py, transM, transP, cor_ends, get_train_loader)
2525
end
2626
end
2727

@@ -45,8 +45,8 @@ function HybridProblem(prob::AbstractHybridProblem; scenario = ())
4545
get_hybridproblem_train_dataloader(rng::AbstractRNG, prob; scenario, kwargs...)
4646
end
4747
end
48-
cor_starts = get_hybridproblem_cor_starts(prob; scenario)
49-
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_starts)
48+
cor_ends = get_hybridproblem_cor_ends(prob; scenario)
49+
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_ends)
5050
end
5151

5252
function update(prob::HybridProblem;
@@ -58,7 +58,7 @@ function update(prob::HybridProblem;
5858
transM::Union{Function, Bijectors.Transform} = prob.transM,
5959
transP::Union{Function, Bijectors.Transform} = prob.transP,
6060
get_train_loader::Function = prob.get_train_loader,
61-
cor_starts::NamedTuple = prob.cor_starts)
61+
cor_ends::NamedTuple = prob.cor_ends)
6262
# prob.θP = θP
6363
# prob.θM = θM
6464
# prob.f = f
@@ -67,15 +67,19 @@ function update(prob::HybridProblem;
6767
# prob.py = py
6868
# prob.transM = transM
6969
# prob.transP = transP
70-
# prob.cor_starts = cor_starts
70+
# prob.cor_ends = cor_ends
7171
# prob.get_train_loader = get_train_loader
72-
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_starts)
72+
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_ends)
7373
end
7474

7575
function get_hybridproblem_par_templates(prob::HybridProblem; scenario::NTuple = ())
7676
(; θP = prob.θP, θM = prob.θM)
7777
end
7878

79+
function get_hybridproblem_ϕunc(prob::HybridProblem; scenario::NTuple = ())
80+
prob.ϕunc
81+
end
82+
7983
function get_hybridproblem_neg_logden_obs(prob::HybridProblem; scenario::NTuple = ())
8084
prob.py
8185
end
@@ -102,8 +106,8 @@ function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::HybridProble
102106
return prob.get_train_loader(rng; kwargs...)
103107
end
104108

105-
function get_hybridproblem_cor_starts(prob::HybridProblem; scenario = ())
106-
prob.cor_starts
109+
function get_hybridproblem_cor_ends(prob::HybridProblem; scenario = ())
110+
prob.cor_ends
107111
end
108112

109113
# function get_hybridproblem_float_type(prob::HybridProblem; scenario::NTuple = ())

src/HybridSolver.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,19 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
4949
scenario, rng = Random.default_rng(), kwargs...)
5050
par_templates = get_hybridproblem_par_templates(prob; scenario)
5151
(; θP, θM) = par_templates
52+
cor_ends = get_hybridproblem_cor_ends(prob; scenario)
5253
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
5354
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
5455
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
55-
θP, θM, ϕg0, solver.n_batch; transP, transM);
56+
θP, θM, cor_ends, ϕg0, solver.n_batch; transP, transM);
5657
use_gpu = (:use_Flux scenario)
5758
ϕ0 = use_gpu ? CuArray(ϕ) : ϕ # TODO replace CuArray by something more general
5859
train_loader = get_hybridproblem_train_dataloader(rng, prob; scenario, solver.n_batch)
5960
f = get_hybridproblem_PBmodel(prob; scenario)
6061
py = get_hybridproblem_neg_logden_obs(prob; scenario)
6162
y_global_o = Float32[] # TODO
62-
loss_elbo = get_loss_elbo(g, transPMs_batch, f, py, y_global_o, interpreters; solver.n_MC)
63+
loss_elbo = get_loss_elbo(
64+
g, transPMs_batch, f, py, y_global_o, interpreters; solver.n_MC, cor_ends)
6365
# test loss function once
6466
l0 = loss_elbo(ϕ0, rng, first(train_loader)...)
6567
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_elbo(ϕ, rng, data...)[1],
@@ -84,12 +86,12 @@ The loss function takes in addition to ϕ, data that changes with minibatch
8486
- xP: drivers for the processmodel: Iterator of size n_site
8587
- y_o, y_unc: matrix of observations and uncertainties, sites in columns
8688
"""
87-
function get_loss_elbo(g, transPMs, f, py, y_o_global, interpreters; n_MC)
88-
let g = g, transPMs = transPMs, f = f, py=py, y_o_global = y_o_global, n_MC = n_MC
89-
interpreters = map(get_concrete, interpreters)
89+
function get_loss_elbo(g, transPMs, f, py, y_o_global, interpreters; n_MC, cor_ends)
90+
let g = g, transPMs = transPMs, f = f, py=py, y_o_global = y_o_global, n_MC = n_MC,
91+
cor_ends = cor_ends, interpreters = map(get_concrete, interpreters)
9092
function loss_elbo(ϕ, rng, xM, xP, y_o, y_unc)
9193
neg_elbo_transnorm_gf(rng, ϕ, g, transPMs, f, py,
92-
xM, xP, y_o, y_unc, interpreters; n_MC)
94+
xM, xP, y_o, y_unc, interpreters; n_MC, cor_ends)
9395
end
9496
end
9597
end

src/HybridVariationalInference.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export AbstractHybridProblem, get_hybridproblem_MLapplicator, get_hybridproblem_
3131
get_hybridproblem_par_templates, get_hybridproblem_transforms, get_hybridproblem_train_dataloader,
3232
get_hybridproblem_neg_logden_obs,
3333
get_hybridproblem_n_covar,
34+
get_hybridproblem_cor_ends,
3435
#update,
3536
gen_cov_pred
3637
include("AbstractHybridProblem.jl")
@@ -53,13 +54,13 @@ include("util_ca.jl")
5354
export neg_logden_indep_normal, entropy_MvNormal
5455
include("logden_normal.jl")
5556

56-
export get_ca_starts
57+
export get_ca_starts, get_ca_ends, get_cor_count
5758
include("cholesky.jl")
5859

5960
export neg_elbo_transnorm_gf, predict_gf
6061
include("elbo.jl")
6162

62-
export init_hybrid_params
63+
export init_hybrid_params, init_hybrid_ϕunc
6364
include("init_hybrid_params.jl")
6465

6566
export AbstractHybridSolver, HybridPointSolver, HybridPosteriorSolver

0 commit comments

Comments
 (0)