Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions dev/doubleMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,24 @@ using OptimizationOptimisers
using Bijectors
using UnicodePlots

const case = DoubleMM.DoubleMMCase()
const prob = DoubleMM.DoubleMMCase()
scenario = (:default,)
rng = StableRNG(111)

par_templates = get_hybridcase_par_templates(case; scenario)
par_templates = get_hybridproblem_par_templates(prob; scenario)

#n_covar = get_hybridcase_n_covar(case; scenario)
#, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
#n_covar = get_hybridproblem_n_covar(prob; scenario)
#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)

(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
) = gen_hybridcase_synthetic(rng, case; scenario);
) = gen_hybridcase_synthetic(rng, prob; scenario);

n_covar = size(xM,1)


#----- fit g to θMs_true
g, ϕg0 = get_hybridcase_MLapplicator(case; scenario);
(; transP, transM) = get_hybridcase_transforms(case; scenario)
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)

function loss_g(ϕg, x, g, transM)
ζMs = g(x, ϕg) # predict the log of the parameters
Expand All @@ -52,8 +52,8 @@ res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), max
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
scatterplot(vec(θMs_true), vec(θMs_pred))

f = get_hybridcase_PBmodel(case; scenario)
py = get_hybridcase_neg_logden_obs(case; scenario)
f = get_hybridproblem_PBmodel(prob; scenario)
py = get_hybridproblem_neg_logden_obs(prob; scenario)

#----------- fit g and θP to y_o
() -> begin
Expand Down Expand Up @@ -85,8 +85,8 @@ end

#---------- HVI
n_MC = 3
(; transP, transM) = get_hybridcase_transforms(case; scenario)
FT = get_hybridcase_float_type(case; scenario)
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
FT = get_hybridproblem_float_type(prob; scenario)

(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
Expand Down Expand Up @@ -167,7 +167,7 @@ mean_σ_o_MC = 0.006042
ϕ = CA.getdata(ϕ_ini) |> Flux.gpu;
xM_gpu = xM |> Flux.gpu;
scenario_flux = (scenario..., :use_Flux)
g_flux, _ = get_hybridcase_MLapplicator(case; scenario = scenario_flux);
g_flux, _ = get_hybridproblem_MLapplicator(prob; scenario = scenario_flux);

# otpimize using LUX
() -> begin
Expand Down Expand Up @@ -205,7 +205,7 @@ gr = Zygote.gradient(fcost,
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...)

train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_Flux))
#train_loader = get_hybridproblem_train_dataloader(prob, rng; scenario = (scenario..., :use_Flux))

optf = Optimization.OptimizationFunction(
(ϕ, data) -> begin
Expand Down
10 changes: 5 additions & 5 deletions ext/HybridVariationalInferenceFluxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
# end

function HVI.construct_3layer_MLApplicator(
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:Flux};
rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:Flux};
scenario::NTuple = ())
(;θM) = get_hybridcase_par_templates(case; scenario)
(;θM) = get_hybridproblem_par_templates(prob; scenario)

Check warning on line 41 in ext/HybridVariationalInferenceFluxExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/HybridVariationalInferenceFluxExt.jl#L41

Added line #L41 was not covered by tests
n_out = length(θM)
n_covar = get_hybridcase_n_covar(case; scenario)
#(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
float_type = get_hybridcase_float_type(case; scenario)
n_covar = get_hybridproblem_n_covar(prob; scenario)

Check warning on line 43 in ext/HybridVariationalInferenceFluxExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/HybridVariationalInferenceFluxExt.jl#L43

Added line #L43 was not covered by tests
#(; n_covar, n_θM) = get_hybridproblem_sizes(prob; scenario)
float_type = get_hybridproblem_float_type(prob; scenario)

Check warning on line 45 in ext/HybridVariationalInferenceFluxExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/HybridVariationalInferenceFluxExt.jl#L45

Added line #L45 was not covered by tests
is_using_dropout = :use_dropout ∈ scenario
is_using_dropout && error("dropout scenario not supported with Flux yet.")
g_chain = Flux.Chain(
Expand Down
8 changes: 4 additions & 4 deletions ext/HybridVariationalInferenceSimpleChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ end
HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)

function HVI.construct_3layer_MLApplicator(
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:SimpleChains};
rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:SimpleChains};
scenario::NTuple = ())
n_covar = get_hybridcase_n_covar(case; scenario)
FloatType = get_hybridcase_float_type(case; scenario)
(;θM) = get_hybridcase_par_templates(case; scenario)
n_covar = get_hybridproblem_n_covar(prob; scenario)
FloatType = get_hybridproblem_float_type(prob; scenario)
(;θM) = get_hybridproblem_par_templates(prob; scenario)
n_out = length(θM)
is_using_dropout = :use_dropout ∈ scenario
g_chain = if is_using_dropout
Expand Down
82 changes: 41 additions & 41 deletions src/hybrid_case.jl → src/AbstractHybridProblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,40 @@
Type to dispatch constructing data and network structures
for different cases of hybrid problem setups

For a specific case, provide functions that specify details
- `get_hybridcase_MLapplicator`
- `get_hybridcase_PBmodel`
- `get_hybridcase_neg_logden_obs`
- `get_hybridcase_par_templates`
- `get_hybridcase_transforms`
- `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`)
For a specific prob, provide functions that specify details
- `get_hybridproblem_MLapplicator`
- `get_hybridproblem_PBmodel`
- `get_hybridproblem_neg_logden_obs`
- `get_hybridproblem_par_templates`
- `get_hybridproblem_transforms`
- `get_hybridproblem_train_dataloader` (default depends on `gen_hybridcase_synthetic`)
optionally
- `gen_hybridcase_synthetic`
- `get_hybridcase_n_covar` (defaults to number of rows in xM in train_dataloader )
- `get_hybridcase_float_type` (defaults to `eltype(θM)`)
- `get_hybridcase_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`)
- `get_hybridproblem_n_covar` (defaults to number of rows in xM in train_dataloader )
- `get_hybridproblem_float_type` (defaults to `eltype(θM)`)
- `get_hybridproblem_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`)
"""
abstract type AbstractHybridCase end;
abstract type AbstractHybridProblem end;


"""
get_hybridcase_MLapplicator([rng::AbstractRNG,] ::AbstractHybridCase; scenario=())
get_hybridproblem_MLapplicator([rng::AbstractRNG,] ::AbstractHybridProblem; scenario=())

Construct the machine learning model fro given problem case and ML-Framework and
Construct the machine learning model fro given problem prob and ML-Framework and
scenario.

returns a Tuple of
- AbstractModelApplicator
- initial parameter vector
"""
function get_hybridcase_MLapplicator end
function get_hybridproblem_MLapplicator end

function get_hybridcase_MLapplicator(case::AbstractHybridCase; scenario=())
get_hybridcase_MLapplicator(Random.default_rng(), case; scenario)
function get_hybridproblem_MLapplicator(prob::AbstractHybridProblem; scenario=())
get_hybridproblem_MLapplicator(Random.default_rng(), prob; scenario)
end

"""
get_hybridcase_PBmodel(::AbstractHybridCase; scenario::NTuple=())
get_hybridproblem_PBmodel(::AbstractHybridProblem; scenario::NTuple=())

Construct the process-based model function
`f(θP::AbstractVector, θMs::AbstractMatrix, x) -> (AbstractVector, AbstractMatrix)`
Expand All @@ -48,59 +48,59 @@
- first, those that are constant across sites
- second, those that vary across sites, with a column for each site
"""
function get_hybridcase_PBmodel end
function get_hybridproblem_PBmodel end

"""
get_hybridcase_neg_logden_obs(::AbstractHybridCase; scenario)
get_hybridproblem_neg_logden_obs(::AbstractHybridProblem; scenario)

Provide a `function(y_obs, ypred) -> Real` that computes the negative logdensity
of the observations, given the predictions.
"""
function get_hybridcase_neg_logden_obs end
function get_hybridproblem_neg_logden_obs end


"""
get_hybridcase_par_templates(::AbstractHybridCase; scenario)
get_hybridproblem_par_templates(::AbstractHybridProblem; scenario)

Provide tuple of templates of ComponentVectors `θP` and `θM`.
"""
function get_hybridcase_par_templates end
function get_hybridproblem_par_templates end


"""
get_hybridcase_transforms(::AbstractHybridCase; scenario)
get_hybridproblem_transforms(::AbstractHybridProblem; scenario)

Return a NamedTupe of
- `transP`: Bijectors.Transform for the global PBM parameters, θP
- `transM`: Bijectors.Transform for the single-site PBM parameters, θM
"""
function get_hybridcase_transforms end
function get_hybridproblem_transforms end

# """
# get_hybridcase_par_templates(::AbstractHybridCase; scenario)
# get_hybridproblem_par_templates(::AbstractHybridProblem; scenario)
# Provide a NamedTuple of number of
# - n_covar: covariates xM
# - n_site: all sites in the data
# - n_batch: sites in one minibatch during fitting
# - n_θM, n_θP: entries in parameter vectors
# """
# function get_hybridcase_sizes end
# function get_hybridproblem_sizes end

"""
get_hybridcase_n_covar(::AbstractHybridCase; scenario)
get_hybridproblem_n_covar(::AbstractHybridProblem; scenario)

Provide the number of covariates. Default returns the number of rows in `xM` from
`get_hybridcase_train_dataloader`.
`get_hybridproblem_train_dataloader`.
"""
function get_hybridcase_n_covar(case::AbstractHybridCase; scenario)
train_loader = get_hybridcase_train_dataloader(Random.default_rng(), case; scenario)
function get_hybridproblem_n_covar(prob::AbstractHybridProblem; scenario)
train_loader = get_hybridproblem_train_dataloader(Random.default_rng(), prob; scenario)
(xM, xP, y_o, y_unc) = first(train_loader)
n_covar = size(xM, 1)
return(n_covar)
end

"""
gen_hybridcase_synthetic([rng,] ::AbstractHybridCase; scenario)
gen_hybridcase_synthetic([rng,] ::AbstractHybridProblem; scenario)

Setup synthetic data, a NamedTuple of
- xM: matrix of covariates, with one column per site
Expand All @@ -114,40 +114,40 @@
function gen_hybridcase_synthetic end

"""
get_hybridcase_float_type(::AbstractHybridCase; scenario)
get_hybridproblem_float_type(::AbstractHybridProblem; scenario)

Determine the FloatType for given Case and scenario, defaults to Float32
"""
function get_hybridcase_float_type(case::AbstractHybridCase; scenario=())
return eltype(get_hybridcase_par_templates(case; scenario).θM)
function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario=())
return eltype(get_hybridproblem_par_templates(prob; scenario).θM)
end

"""
get_hybridcase_train_dataloader([rng,] ::AbstractHybridCase; scenario)
get_hybridproblem_train_dataloader([rng,] ::AbstractHybridProblem; scenario)

Return a DataLoader that provides a tuple of
- `xM`: matrix of covariates, with one column per site
- `xP`: Iterator of process-model drivers, with one element per site
- `y_o`: matrix of observations with added noise, with one column per site
- `y_unc`: matrix `sizeof(y_o)` of uncertainty information
"""
function get_hybridcase_train_dataloader(rng::AbstractRNG, case::AbstractHybridCase;
function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::AbstractHybridProblem;
scenario = ())
(; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, case; scenario)
(; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, prob; scenario)
n_batch = 10
xM_gpu = :use_Flux ∈ scenario ? CuArray(xM) : xM
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
return(train_loader)
end

function get_hybridcase_train_dataloader(case::AbstractHybridCase; scenario = ())
function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenario = ())

Check warning on line 143 in src/AbstractHybridProblem.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractHybridProblem.jl#L143

Added line #L143 was not covered by tests
rng::AbstractRNG = Random.default_rng()
get_hybridcase_train_dataloader(rng, case; scenario)
get_hybridproblem_train_dataloader(rng, prob; scenario)

Check warning on line 145 in src/AbstractHybridProblem.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractHybridProblem.jl#L145

Added line #L145 was not covered by tests
end


"""
get_hybridcase_cor_starts(case::AbstractHybridCase; scenario)
get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario)

Specify blocks in correlation matrices among parameters.
Returns a NamedTuple.
Expand All @@ -163,7 +163,7 @@
If there is only single block of all ML-predicted parameters being correlated
with each other then this block starts at position 1: `(P=(1,3), M=(1,))`.
"""
function get_hybridcase_cor_starts(case::AbstractHybridCase; scenario = ())
function get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario = ())

Check warning on line 166 in src/AbstractHybridProblem.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractHybridProblem.jl#L166

Added line #L166 was not covered by tests
(P=(1,), M=(1,))
end

Expand Down
2 changes: 1 addition & 1 deletion src/ComponentArrayInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ function ComponentArrayInterpreter(
ComponentArrayInterpreter(axes_ext)
end

# ambuiguity with two empty Tuples (edge case that does not make sense)
# ambuiguity with two empty Tuples (edge prob that does not make sense)
# Empty ComponentVector with no other array dimensions -> empty componentVector
function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{})
ComponentArrayInterpreter(CA.ComponentVector())
Expand Down
26 changes: 13 additions & 13 deletions src/DoubleMM/f_doubleMM.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct DoubleMMCase <: AbstractHybridCase end
struct DoubleMMCase <: AbstractHybridProblem end


θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
Expand All @@ -18,19 +18,19 @@
return (y)
end

function HVI.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
function HVI.get_hybridproblem_par_templates(::DoubleMMCase; scenario::NTuple = ())
(; θP, θM)
end

function HVI.get_hybridcase_transforms(::DoubleMMCase; scenario::NTuple = ())
function HVI.get_hybridproblem_transforms(::DoubleMMCase; scenario::NTuple = ())
(; transP, transM)
end

function HVI.get_hybridcase_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ())
function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ())

Check warning on line 29 in src/DoubleMM/f_doubleMM.jl

View check run for this annotation

Codecov / codecov/patch

src/DoubleMM/f_doubleMM.jl#L29

Added line #L29 was not covered by tests
neg_logden_indep_normal
end

# function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
# function HVI.get_hybridproblem_sizes(::DoubleMMCase; scenario = ())
# n_covar_pc = 2
# n_covar = n_covar_pc + 3 # linear dependent
# #n_site = 10^n_covar_pc
Expand All @@ -41,7 +41,7 @@
# (; n_covar, n_batch, n_θM, n_θP)
# end

function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
function HVI.get_hybridproblem_PBmodel(::DoubleMMCase; scenario::NTuple = ())
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
pred_sites = applyf(f_doubleMM, θMs, θP, x)
Expand All @@ -50,26 +50,26 @@
end
end

# function HVI.get_hybridcase_float_type(::DoubleMMCase; scenario)
# function HVI.get_hybridproblem_float_type(::DoubleMMCase; scenario)
# return Float32
# end

const xP_S1 = Float32[1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]

function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, case::DoubleMMCase;
function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
scenario = ())
n_covar_pc = 2
n_site = 200
n_covar = 5
n_θM = length(θM)
FloatType = get_hybridcase_float_type(case; scenario)
FloatType = get_hybridproblem_float_type(prob; scenario)
xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
rhodec = 8, is_using_dropout = false)
int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,))
# normalize to be distributed around the prescribed true values
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1)))
f = get_hybridcase_PBmodel(case; scenario)
f = get_hybridproblem_PBmodel(prob; scenario)
xP = fill((;S1=xP_S1, S2=xP_S2), n_site)
y_global_true, y_true = f(θP, θMs_true, xP)
σ_o = FloatType(0.01)
Expand All @@ -91,10 +91,10 @@
)
end

function HVI.get_hybridcase_MLapplicator(
rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase; scenario = ())
function HVI.get_hybridproblem_MLapplicator(
rng::AbstractRNG, prob::HVI.DoubleMM.DoubleMMCase; scenario = ())
ml_engine = select_ml_engine(; scenario)
construct_3layer_MLApplicator(rng, case, ml_engine; scenario)
construct_3layer_MLApplicator(rng, prob, ml_engine; scenario)
end


Expand Down
Loading