Merged
2 changes: 1 addition & 1 deletion dev/doubleMM.jl
@@ -55,7 +55,7 @@ scatterplot(θMs_true[2,:], θMs[2,:])
prob1o.θP
scatterplot(vec(y_true), vec(y_pred))

# still overestimating θMs
# still overestimating θMs and θP

() -> begin # with more iterations?
prob2 = prob1o
26 changes: 20 additions & 6 deletions src/AbstractHybridProblem.jl
@@ -1,19 +1,25 @@
"""
Type to dispatch constructing data and network structures
for different cases of hybrid problem setups
for different cases of hybrid problem setups.

For a specific prob, provide functions that specify the details:
- `get_hybridproblem_MLapplicator`
- `get_hybridproblem_transforms`
- `get_hybridproblem_PBmodel`
- `get_hybridproblem_neg_logden_obs`
- `get_hybridproblem_par_templates`
- `get_hybridproblem_transforms`
- `get_hybridproblem_ϕunc`
- `get_hybridproblem_train_dataloader` (default depends on `gen_hybridcase_synthetic`)
optionally
- `gen_hybridcase_synthetic`
- `get_hybridproblem_n_covar` (defaults to the number of rows of `xM` in the train_dataloader)
- `get_hybridproblem_float_type` (defaults to `eltype(θM)`)
- `get_hybridproblem_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`)
- `get_hybridproblem_cor_ends` (defaults to include all correlations: `(P=(1,), M=(1,))`)

The initial values of the parameters to estimate are spread across
- `ϕg`: parameters of the MLapplicator, returned by `get_hybridproblem_MLapplicator`
- `ζP`: mean of the PBmodel parameters, returned by `get_hybridproblem_par_templates`
- `ϕunc`: additional parameters of the approximate posterior, returned by `get_hybridproblem_ϕunc`
"""
abstract type AbstractHybridProblem end;
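To make the interface listed in the docstring above concrete, a minimal sketch of a custom problem type implementing two of the accessors. The type name `MyProblem`, the parameter names (`k1`, `k2`, `r0`, `K`), and their values are hypothetical placeholders; the sketch assumes the package module `HybridVariationalInference` and `ComponentArrays` are available.

using ComponentArrays
using HybridVariationalInference
import HybridVariationalInference: get_hybridproblem_par_templates, get_hybridproblem_cor_ends

struct MyProblem <: AbstractHybridProblem end  # hypothetical problem type

# Templates of the global process parameters θP and the ML-predicted parameters θM
function get_hybridproblem_par_templates(::MyProblem; scenario = ())
    (; θP = ComponentVector(k1 = 0.2, k2 = 1.5),
        θM = ComponentVector(r0 = 0.5, K = 2.0))
end

# One correlation block per group, matching the documented default
function get_hybridproblem_cor_ends(prob::MyProblem; scenario = ())
    pt = get_hybridproblem_par_templates(prob; scenario)
    (P = [length(pt.θP)], M = [length(pt.θM)])
end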

@@ -64,6 +70,13 @@ Provide tuple of templates of ComponentVectors `θP` and `θM`.
"""
function get_hybridproblem_par_templates end

"""
get_hybridproblem_ϕunc(::AbstractHybridProblem; scenario)

Provide a ComponentArray of the initial additional parameters of the approximate posterior.
"""
function get_hybridproblem_ϕunc end
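Purely as an illustration of the expected shape: the component names and sizes below are assumptions, not the package's actual layout (which is presumably set up by `init_hybrid_ϕunc`). The idea is log standard deviations of the approximate posterior plus parameters for the correlation blocks given by `get_hybridproblem_cor_ends`.

using ComponentArrays

# Hypothetical layout of ϕunc (names and sizes are assumptions for illustration)
ϕunc0 = ComponentVector(
    logσ_P = zeros(2),  # uncertainty of 2 global parameters θP
    logσ_M = zeros(2),  # uncertainty of 2 ML-predicted parameters θM
    ρ_P = zeros(1),     # unconstrained parameters of the θP correlation block
    ρ_M = zeros(1))     # unconstrained parameters of the θM correlation block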

"""
get_hybridproblem_transforms(::AbstractHybridProblem; scenario)

@@ -143,7 +156,7 @@ function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenari
end

"""
get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario)
get_hybridproblem_cor_ends(prob::AbstractHybridProblem; scenario)

Specify blocks in correlation matrices among parameters.
Returns a NamedTuple.
@@ -159,6 +172,7 @@ then the first subrange starts at position 1 and the second subrange starts at p
If there is only a single block of all ML-predicted parameters being correlated
with each other, then this block starts at position 1: `(P=(1,3), M=(1,))`.
"""
function get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario = ())
(P = (1,), M = (1,))
function get_hybridproblem_cor_ends(prob::AbstractHybridProblem; scenario = ())
pt = get_hybridproblem_par_templates(prob; scenario)
(P = [length(pt.θP)], M = [length(pt.θM)])
end
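A brief illustration, under the assumption (consistent with the default `[length(θP)]` above) that each entry of `cor_ends` is the end index of one correlation block; the numbers are illustrative only.

# Illustrative only: 2 global parameters in a single block; 5 ML-predicted
# parameters split into a block over parameters 1–3 and a block over 4–5
cor_ends_blocked = (P = [2], M = [3, 5])

# A single block over all parameters of each group (the default above)
cor_ends_single = (P = [2], M = [5])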
24 changes: 14 additions & 10 deletions src/HybridProblem.jl
@@ -7,7 +7,7 @@
py
transP
transM
cor_starts # = (P=(1,),M=(1,))
cor_ends # = (P=(1,),M=(1,))
get_train_loader
# inner constructor to constrain the types
function HybridProblem(
@@ -20,8 +20,8 @@
#train_loader::DataLoader,
# return a function that constructs the trainloader based on n_batch
get_train_loader::Function,
cor_starts::NamedTuple = (P = (1,), M = (1,)))
new(θP, θM, f, g, ϕg, py, transM, transP, cor_starts, get_train_loader)
cor_ends::NamedTuple = (P = [length(θP)], M = [length(θM)]))
new(θP, θM, f, g, ϕg, py, transM, transP, cor_ends, get_train_loader)
end
end

@@ -45,8 +45,8 @@
get_hybridproblem_train_dataloader(rng::AbstractRNG, prob; scenario, kwargs...)
end
end
cor_starts = get_hybridproblem_cor_starts(prob; scenario)
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_starts)
cor_ends = get_hybridproblem_cor_ends(prob; scenario)
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_ends)

Codecov warning: added lines src/HybridProblem.jl#L48-L49 were not covered by tests.
end

function update(prob::HybridProblem;
Expand All @@ -58,7 +58,7 @@
transM::Union{Function, Bijectors.Transform} = prob.transM,
transP::Union{Function, Bijectors.Transform} = prob.transP,
get_train_loader::Function = prob.get_train_loader,
cor_starts::NamedTuple = prob.cor_starts)
cor_ends::NamedTuple = prob.cor_ends)
# prob.θP = θP
# prob.θM = θM
# prob.f = f
@@ -67,15 +67,19 @@
# prob.py = py
# prob.transM = transM
# prob.transP = transP
# prob.cor_starts = cor_starts
# prob.cor_ends = cor_ends
# prob.get_train_loader = get_train_loader
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_starts)
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_ends)

Codecov warning: added line src/HybridProblem.jl#L72 was not covered by tests.
end

function get_hybridproblem_par_templates(prob::HybridProblem; scenario::NTuple = ())
(; θP = prob.θP, θM = prob.θM)
end

function get_hybridproblem_ϕunc(prob::HybridProblem; scenario::NTuple = ())
prob.ϕunc

Codecov warning: added lines src/HybridProblem.jl#L79-L80 were not covered by tests.
end

function get_hybridproblem_neg_logden_obs(prob::HybridProblem; scenario::NTuple = ())
prob.py
end
@@ -102,8 +106,8 @@
return prob.get_train_loader(rng; kwargs...)
end

function get_hybridproblem_cor_starts(prob::HybridProblem; scenario = ())
prob.cor_starts
function get_hybridproblem_cor_ends(prob::HybridProblem; scenario = ())
prob.cor_ends
end

# function get_hybridproblem_float_type(prob::HybridProblem; scenario::NTuple = ())
14 changes: 8 additions & 6 deletions src/HybridSolver.jl
@@ -49,17 +49,19 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
scenario, rng = Random.default_rng(), kwargs...)
par_templates = get_hybridproblem_par_templates(prob; scenario)
(; θP, θM) = par_templates
cor_ends = get_hybridproblem_cor_ends(prob; scenario)
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
θP, θM, ϕg0, solver.n_batch; transP, transM);
θP, θM, cor_ends, ϕg0, solver.n_batch; transP, transM);
use_gpu = (:use_Flux ∈ scenario)
ϕ0 = use_gpu ? CuArray(ϕ) : ϕ # TODO replace CuArray by something more general
train_loader = get_hybridproblem_train_dataloader(rng, prob; scenario, solver.n_batch)
f = get_hybridproblem_PBmodel(prob; scenario)
py = get_hybridproblem_neg_logden_obs(prob; scenario)
y_global_o = Float32[] # TODO
loss_elbo = get_loss_elbo(g, transPMs_batch, f, py, y_global_o, interpreters; solver.n_MC)
loss_elbo = get_loss_elbo(
g, transPMs_batch, f, py, y_global_o, interpreters; solver.n_MC, cor_ends)
# test loss function once
l0 = loss_elbo(ϕ0, rng, first(train_loader)...)
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_elbo(ϕ, rng, data...)[1],
@@ -84,12 +86,12 @@ The loss function takes, in addition to ϕ, data that changes with minibatch
- xP: drivers for the processmodel: Iterator of size n_site
- y_o, y_unc: matrix of observations and uncertainties, sites in columns
"""
function get_loss_elbo(g, transPMs, f, py, y_o_global, interpreters; n_MC)
let g = g, transPMs = transPMs, f = f, py=py, y_o_global = y_o_global, n_MC = n_MC
interpreters = map(get_concrete, interpreters)
function get_loss_elbo(g, transPMs, f, py, y_o_global, interpreters; n_MC, cor_ends)
let g = g, transPMs = transPMs, f = f, py=py, y_o_global = y_o_global, n_MC = n_MC,
cor_ends = cor_ends, interpreters = map(get_concrete, interpreters)
function loss_elbo(ϕ, rng, xM, xP, y_o, y_unc)
neg_elbo_transnorm_gf(rng, ϕ, g, transPMs, f, py,
xM, xP, y_o, y_unc, interpreters; n_MC)
xM, xP, y_o, y_unc, interpreters; n_MC, cor_ends)
end
end
end
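For orientation, a sketch of how the returned closure is used; it mirrors the calls in `CommonSolve.solve` above and reuses the variables defined there (`g`, `transPMs_batch`, `f`, `py`, `y_global_o`, `interpreters`, `train_loader`, `ϕ0`, `rng`); `n_MC = 3` is an arbitrary illustrative value.

# Build the loss once per solve; cor_ends now also flows into the ELBO computation
loss_elbo = get_loss_elbo(g, transPMs_batch, f, py, y_global_o, interpreters;
    n_MC = 3, cor_ends)
# Evaluate on one minibatch from the train_loader; the first element of the
# result is the negative ELBO that the optimizer minimizes
xM, xP, y_o, y_unc = first(train_loader)
l1 = loss_elbo(ϕ0, rng, xM, xP, y_o, y_unc)[1]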
5 changes: 3 additions & 2 deletions src/HybridVariationalInference.jl
@@ -31,6 +31,7 @@ export AbstractHybridProblem, get_hybridproblem_MLapplicator, get_hybridproblem_
get_hybridproblem_par_templates, get_hybridproblem_transforms, get_hybridproblem_train_dataloader,
get_hybridproblem_neg_logden_obs,
get_hybridproblem_n_covar,
get_hybridproblem_cor_ends,
#update,
gen_cov_pred
include("AbstractHybridProblem.jl")
@@ -53,13 +54,13 @@ include("util_ca.jl")
export neg_logden_indep_normal, entropy_MvNormal
include("logden_normal.jl")

export get_ca_starts
export get_ca_starts, get_ca_ends, get_cor_count
include("cholesky.jl")

export neg_elbo_transnorm_gf, predict_gf
include("elbo.jl")

export init_hybrid_params
export init_hybrid_params, init_hybrid_ϕunc
include("init_hybrid_params.jl")

export AbstractHybridSolver, HybridPointSolver, HybridPosteriorSolver