diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index cc70fa2..74ce579 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -16,24 +16,24 @@ using OptimizationOptimisers using Bijectors using UnicodePlots -const case = DoubleMM.DoubleMMCase() +const prob = DoubleMM.DoubleMMCase() scenario = (:default,) rng = StableRNG(111) -par_templates = get_hybridcase_par_templates(case; scenario) +par_templates = get_hybridproblem_par_templates(prob; scenario) -#n_covar = get_hybridcase_n_covar(case; scenario) -#, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) +#n_covar = get_hybridproblem_n_covar(prob; scenario) +#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario) (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc -) = gen_hybridcase_synthetic(rng, case; scenario); +) = gen_hybridcase_synthetic(rng, prob; scenario); n_covar = size(xM,1) #----- fit g to θMs_true -g, ϕg0 = get_hybridcase_MLapplicator(case; scenario); -(; transP, transM) = get_hybridcase_transforms(case; scenario) +g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario); +(; transP, transM) = get_hybridproblem_transforms(prob; scenario) function loss_g(ϕg, x, g, transM) ζMs = g(x, ϕg) # predict the log of the parameters @@ -52,8 +52,8 @@ res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), max l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM) scatterplot(vec(θMs_true), vec(θMs_pred)) -f = get_hybridcase_PBmodel(case; scenario) -py = get_hybridcase_neg_logden_obs(case; scenario) +f = get_hybridproblem_PBmodel(prob; scenario) +py = get_hybridproblem_neg_logden_obs(prob; scenario) #----------- fit g and θP to y_o () -> begin @@ -85,8 +85,8 @@ end #---------- HVI n_MC = 3 -(; transP, transM) = get_hybridcase_transforms(case; scenario) -FT = get_hybridcase_float_type(case; scenario) +(; transP, transM) = get_hybridproblem_transforms(prob; scenario) +FT = get_hybridproblem_float_type(prob; scenario) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM); @@ -167,7 +167,7 @@ mean_σ_o_MC = 0.006042 ϕ = CA.getdata(ϕ_ini) |> Flux.gpu; xM_gpu = xM |> Flux.gpu; scenario_flux = (scenario..., :use_Flux) -g_flux, _ = get_hybridcase_MLapplicator(case; scenario = scenario_flux); +g_flux, _ = get_hybridproblem_MLapplicator(prob; scenario = scenario_flux); # otpimize using LUX () -> begin @@ -205,7 +205,7 @@ gr = Zygote.gradient(fcost, gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...) train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch) -#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_Flux)) +#train_loader = get_hybridproblem_train_dataloader(prob, rng; scenario = (scenario..., :use_Flux)) optf = Optimization.OptimizationFunction( (ϕ, data) -> begin diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl index 02b9fa4..0c92933 100644 --- a/ext/HybridVariationalInferenceFluxExt.jl +++ b/ext/HybridVariationalInferenceFluxExt.jl @@ -36,13 +36,13 @@ end # end function HVI.construct_3layer_MLApplicator( - rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:Flux}; + rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:Flux}; scenario::NTuple = ()) - (;θM) = get_hybridcase_par_templates(case; scenario) + (;θM) = get_hybridproblem_par_templates(prob; scenario) n_out = length(θM) - n_covar = get_hybridcase_n_covar(case; scenario) - #(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario) - float_type = get_hybridcase_float_type(case; scenario) + n_covar = get_hybridproblem_n_covar(prob; scenario) + #(; n_covar, n_θM) = get_hybridproblem_sizes(prob; scenario) + float_type = get_hybridproblem_float_type(prob; scenario) is_using_dropout = :use_dropout ∈ scenario is_using_dropout && error("dropout scenario not supported with Flux yet.") g_chain = Flux.Chain( diff --git a/ext/HybridVariationalInferenceSimpleChainsExt.jl b/ext/HybridVariationalInferenceSimpleChainsExt.jl index df1a122..03ecd42 100644 --- a/ext/HybridVariationalInferenceSimpleChainsExt.jl +++ b/ext/HybridVariationalInferenceSimpleChainsExt.jl @@ -20,11 +20,11 @@ end HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ) function HVI.construct_3layer_MLApplicator( - rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:SimpleChains}; + rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:SimpleChains}; scenario::NTuple = ()) - n_covar = get_hybridcase_n_covar(case; scenario) - FloatType = get_hybridcase_float_type(case; scenario) - (;θM) = get_hybridcase_par_templates(case; scenario) + n_covar = get_hybridproblem_n_covar(prob; scenario) + FloatType = get_hybridproblem_float_type(prob; scenario) + (;θM) = get_hybridproblem_par_templates(prob; scenario) n_out = length(θM) is_using_dropout = :use_dropout ∈ scenario g_chain = if is_using_dropout diff --git a/src/hybrid_case.jl b/src/AbstractHybridProblem.jl similarity index 56% rename from src/hybrid_case.jl rename to src/AbstractHybridProblem.jl index 585a266..f81b5c3 100644 --- a/src/hybrid_case.jl +++ b/src/AbstractHybridProblem.jl @@ -2,40 +2,40 @@ Type to dispatch constructing data and network structures for different cases of hybrid problem setups -For a specific case, provide functions that specify details -- `get_hybridcase_MLapplicator` -- `get_hybridcase_PBmodel` -- `get_hybridcase_neg_logden_obs` -- `get_hybridcase_par_templates` -- `get_hybridcase_transforms` -- `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`) +For a specific prob, provide functions that specify details +- `get_hybridproblem_MLapplicator` +- `get_hybridproblem_PBmodel` +- `get_hybridproblem_neg_logden_obs` +- `get_hybridproblem_par_templates` +- `get_hybridproblem_transforms` +- `get_hybridproblem_train_dataloader` (default depends on `gen_hybridcase_synthetic`) optionally - `gen_hybridcase_synthetic` -- `get_hybridcase_n_covar` (defaults to number of rows in xM in train_dataloader ) -- `get_hybridcase_float_type` (defaults to `eltype(θM)`) -- `get_hybridcase_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`) +- `get_hybridproblem_n_covar` (defaults to number of rows in xM in train_dataloader ) +- `get_hybridproblem_float_type` (defaults to `eltype(θM)`) +- `get_hybridproblem_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`) """ -abstract type AbstractHybridCase end; +abstract type AbstractHybridProblem end; """ - get_hybridcase_MLapplicator([rng::AbstractRNG,] ::AbstractHybridCase; scenario=()) + get_hybridproblem_MLapplicator([rng::AbstractRNG,] ::AbstractHybridProblem; scenario=()) -Construct the machine learning model fro given problem case and ML-Framework and +Construct the machine learning model fro given problem prob and ML-Framework and scenario. returns a Tuple of - AbstractModelApplicator - initial parameter vector """ -function get_hybridcase_MLapplicator end +function get_hybridproblem_MLapplicator end -function get_hybridcase_MLapplicator(case::AbstractHybridCase; scenario=()) - get_hybridcase_MLapplicator(Random.default_rng(), case; scenario) +function get_hybridproblem_MLapplicator(prob::AbstractHybridProblem; scenario=()) + get_hybridproblem_MLapplicator(Random.default_rng(), prob; scenario) end """ - get_hybridcase_PBmodel(::AbstractHybridCase; scenario::NTuple=()) + get_hybridproblem_PBmodel(::AbstractHybridProblem; scenario::NTuple=()) Construct the process-based model function `f(θP::AbstractVector, θMs::AbstractMatrix, x) -> (AbstractVector, AbstractMatrix)` @@ -48,59 +48,59 @@ returns a tuple of predictions with components - first, those that are constant across sites - second, those that vary across sites, with a column for each site """ -function get_hybridcase_PBmodel end +function get_hybridproblem_PBmodel end """ - get_hybridcase_neg_logden_obs(::AbstractHybridCase; scenario) + get_hybridproblem_neg_logden_obs(::AbstractHybridProblem; scenario) Provide a `function(y_obs, ypred) -> Real` that computes the negative logdensity of the observations, given the predictions. """ -function get_hybridcase_neg_logden_obs end +function get_hybridproblem_neg_logden_obs end """ - get_hybridcase_par_templates(::AbstractHybridCase; scenario) + get_hybridproblem_par_templates(::AbstractHybridProblem; scenario) Provide tuple of templates of ComponentVectors `θP` and `θM`. """ -function get_hybridcase_par_templates end +function get_hybridproblem_par_templates end """ - get_hybridcase_transforms(::AbstractHybridCase; scenario) + get_hybridproblem_transforms(::AbstractHybridProblem; scenario) Return a NamedTupe of - `transP`: Bijectors.Transform for the global PBM parameters, θP - `transM`: Bijectors.Transform for the single-site PBM parameters, θM """ -function get_hybridcase_transforms end +function get_hybridproblem_transforms end # """ -# get_hybridcase_par_templates(::AbstractHybridCase; scenario) +# get_hybridproblem_par_templates(::AbstractHybridProblem; scenario) # Provide a NamedTuple of number of # - n_covar: covariates xM # - n_site: all sites in the data # - n_batch: sites in one minibatch during fitting # - n_θM, n_θP: entries in parameter vectors # """ -# function get_hybridcase_sizes end +# function get_hybridproblem_sizes end """ - get_hybridcase_n_covar(::AbstractHybridCase; scenario) + get_hybridproblem_n_covar(::AbstractHybridProblem; scenario) Provide the number of covariates. Default returns the number of rows in `xM` from -`get_hybridcase_train_dataloader`. +`get_hybridproblem_train_dataloader`. """ -function get_hybridcase_n_covar(case::AbstractHybridCase; scenario) - train_loader = get_hybridcase_train_dataloader(Random.default_rng(), case; scenario) +function get_hybridproblem_n_covar(prob::AbstractHybridProblem; scenario) + train_loader = get_hybridproblem_train_dataloader(Random.default_rng(), prob; scenario) (xM, xP, y_o, y_unc) = first(train_loader) n_covar = size(xM, 1) return(n_covar) end """ - gen_hybridcase_synthetic([rng,] ::AbstractHybridCase; scenario) + gen_hybridcase_synthetic([rng,] ::AbstractHybridProblem; scenario) Setup synthetic data, a NamedTuple of - xM: matrix of covariates, with one column per site @@ -114,16 +114,16 @@ Setup synthetic data, a NamedTuple of function gen_hybridcase_synthetic end """ - get_hybridcase_float_type(::AbstractHybridCase; scenario) + get_hybridproblem_float_type(::AbstractHybridProblem; scenario) Determine the FloatType for given Case and scenario, defaults to Float32 """ -function get_hybridcase_float_type(case::AbstractHybridCase; scenario=()) - return eltype(get_hybridcase_par_templates(case; scenario).θM) +function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario=()) + return eltype(get_hybridproblem_par_templates(prob; scenario).θM) end """ - get_hybridcase_train_dataloader([rng,] ::AbstractHybridCase; scenario) + get_hybridproblem_train_dataloader([rng,] ::AbstractHybridProblem; scenario) Return a DataLoader that provides a tuple of - `xM`: matrix of covariates, with one column per site @@ -131,23 +131,23 @@ Return a DataLoader that provides a tuple of - `y_o`: matrix of observations with added noise, with one column per site - `y_unc`: matrix `sizeof(y_o)` of uncertainty information """ -function get_hybridcase_train_dataloader(rng::AbstractRNG, case::AbstractHybridCase; +function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::AbstractHybridProblem; scenario = ()) - (; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, case; scenario) + (; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, prob; scenario) n_batch = 10 xM_gpu = :use_Flux ∈ scenario ? CuArray(xM) : xM train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch) return(train_loader) end -function get_hybridcase_train_dataloader(case::AbstractHybridCase; scenario = ()) +function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenario = ()) rng::AbstractRNG = Random.default_rng() - get_hybridcase_train_dataloader(rng, case; scenario) + get_hybridproblem_train_dataloader(rng, prob; scenario) end """ - get_hybridcase_cor_starts(case::AbstractHybridCase; scenario) + get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario) Specify blocks in correlation matrices among parameters. Returns a NamedTuple. @@ -163,7 +163,7 @@ then the first subrange starts at position 1 and the second subrange starts at p If there is only single block of all ML-predicted parameters being correlated with each other then this block starts at position 1: `(P=(1,3), M=(1,))`. """ -function get_hybridcase_cor_starts(case::AbstractHybridCase; scenario = ()) +function get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario = ()) (P=(1,), M=(1,)) end diff --git a/src/ComponentArrayInterpreter.jl b/src/ComponentArrayInterpreter.jl index 7887615..46cd5f2 100644 --- a/src/ComponentArrayInterpreter.jl +++ b/src/ComponentArrayInterpreter.jl @@ -137,7 +137,7 @@ function ComponentArrayInterpreter( ComponentArrayInterpreter(axes_ext) end -# ambuiguity with two empty Tuples (edge case that does not make sense) +# ambuiguity with two empty Tuples (edge prob that does not make sense) # Empty ComponentVector with no other array dimensions -> empty componentVector function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{}) ComponentArrayInterpreter(CA.ComponentVector()) diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl index ee6c21f..008eceb 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -1,4 +1,4 @@ -struct DoubleMMCase <: AbstractHybridCase end +struct DoubleMMCase <: AbstractHybridProblem end θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0) @@ -18,19 +18,19 @@ function f_doubleMM(θ::AbstractVector, x) return (y) end -function HVI.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ()) +function HVI.get_hybridproblem_par_templates(::DoubleMMCase; scenario::NTuple = ()) (; θP, θM) end -function HVI.get_hybridcase_transforms(::DoubleMMCase; scenario::NTuple = ()) +function HVI.get_hybridproblem_transforms(::DoubleMMCase; scenario::NTuple = ()) (; transP, transM) end -function HVI.get_hybridcase_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ()) +function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ()) neg_logden_indep_normal end -# function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ()) +# function HVI.get_hybridproblem_sizes(::DoubleMMCase; scenario = ()) # n_covar_pc = 2 # n_covar = n_covar_pc + 3 # linear dependent # #n_site = 10^n_covar_pc @@ -41,7 +41,7 @@ end # (; n_covar, n_batch, n_θM, n_θP) # end -function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ()) +function HVI.get_hybridproblem_PBmodel(::DoubleMMCase; scenario::NTuple = ()) #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x) pred_sites = applyf(f_doubleMM, θMs, θP, x) @@ -50,26 +50,26 @@ function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ()) end end -# function HVI.get_hybridcase_float_type(::DoubleMMCase; scenario) +# function HVI.get_hybridproblem_float_type(::DoubleMMCase; scenario) # return Float32 # end const xP_S1 = Float32[1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1] const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0] -function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, case::DoubleMMCase; +function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, prob::DoubleMMCase; scenario = ()) n_covar_pc = 2 n_site = 200 n_covar = 5 n_θM = length(θM) - FloatType = get_hybridcase_float_type(case; scenario) + FloatType = get_hybridproblem_float_type(prob; scenario) xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM; rhodec = 8, is_using_dropout = false) int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,)) # normalize to be distributed around the prescribed true values θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1))) - f = get_hybridcase_PBmodel(case; scenario) + f = get_hybridproblem_PBmodel(prob; scenario) xP = fill((;S1=xP_S1, S2=xP_S2), n_site) y_global_true, y_true = f(θP, θMs_true, xP) σ_o = FloatType(0.01) @@ -91,10 +91,10 @@ function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, case::DoubleMMCase; ) end -function HVI.get_hybridcase_MLapplicator( - rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase; scenario = ()) +function HVI.get_hybridproblem_MLapplicator( + rng::AbstractRNG, prob::HVI.DoubleMM.DoubleMMCase; scenario = ()) ml_engine = select_ml_engine(; scenario) - construct_3layer_MLApplicator(rng, case, ml_engine; scenario) + construct_3layer_MLApplicator(rng, prob, ml_engine; scenario) end diff --git a/src/HybridProblem.jl b/src/HybridProblem.jl index 9f89801..c1f5b8b 100644 --- a/src/HybridProblem.jl +++ b/src/HybridProblem.jl @@ -1,4 +1,4 @@ -struct HybridProblem <: AbstractHybridCase +struct HybridProblem <: AbstractHybridProblem θP θM f @@ -32,41 +32,41 @@ function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, HybridProblem(θP, θM, g, ϕg, f, args...; kwargs...) end -function get_hybridcase_par_templates(prob::HybridProblem; scenario::NTuple = ()) +function get_hybridproblem_par_templates(prob::HybridProblem; scenario::NTuple = ()) (; θP = prob.θP, θM = prob.θM) end -function get_hybridcase_neg_logden_obs(prob::HybridProblem; scenario::NTuple = ()) +function get_hybridproblem_neg_logden_obs(prob::HybridProblem; scenario::NTuple = ()) prob.py end -function get_hybridcase_transforms(prob::HybridProblem; scenario::NTuple = ()) +function get_hybridproblem_transforms(prob::HybridProblem; scenario::NTuple = ()) (; transP = prob.transP, transM = prob.transM) end -# function get_hybridcase_sizes(prob::HybridProblem; scenario::NTuple = ()) +# function get_hybridproblem_sizes(prob::HybridProblem; scenario::NTuple = ()) # n_θM = length(prob.θM) # n_θP = length(prob.θP) # (; n_covar=prob.n_covar, n_batch=prob.n_batch, n_θM, n_θP) # end -function get_hybridcase_PBmodel(prob::HybridProblem; scenario::NTuple = ()) +function get_hybridproblem_PBmodel(prob::HybridProblem; scenario::NTuple = ()) prob.f end -function get_hybridcase_MLapplicator(prob::HybridProblem; scenario::NTuple = ()); +function get_hybridproblem_MLapplicator(prob::HybridProblem; scenario::NTuple = ()); prob.g, prob.ϕg end -function get_hybridcase_train_dataloader(rng::AbstractRNG, prob::HybridProblem; scenario = ()) +function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::HybridProblem; scenario = ()) return(prob.train_loader) end -function get_hybridcase_cor_starts(prob::HybridProblem; scenario = ()) +function get_hybridproblem_cor_starts(prob::HybridProblem; scenario = ()) prob.cor_starts end -# function get_hybridcase_float_type(prob::HybridProblem; scenario::NTuple = ()) +# function get_hybridproblem_float_type(prob::HybridProblem; scenario::NTuple = ()) # eltype(prob.θM) # end diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index 6a56427..f81c6ef 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -23,13 +23,13 @@ include("ModelApplicator.jl") export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler include("GPUDataHandler.jl") -export AbstractHybridCase, get_hybridcase_MLapplicator, get_hybridcase_PBmodel, - get_hybridcase_float_type, gen_hybridcase_synthetic, - get_hybridcase_par_templates, get_hybridcase_transforms, get_hybridcase_train_dataloader, - get_hybridcase_neg_logden_obs, - get_hybridcase_n_covar, +export AbstractHybridProblem, get_hybridproblem_MLapplicator, get_hybridproblem_PBmodel, + get_hybridproblem_float_type, gen_hybridcase_synthetic, + get_hybridproblem_par_templates, get_hybridproblem_transforms, get_hybridproblem_train_dataloader, + get_hybridproblem_neg_logden_obs, + get_hybridproblem_n_covar, gen_cov_pred -include("hybrid_case.jl") +include("AbstractHybridProblem.jl") export HybridProblem include("HybridProblem.jl") diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl index de14697..9c30d47 100644 --- a/src/ModelApplicator.jl +++ b/src/ModelApplicator.jl @@ -38,7 +38,7 @@ end """ construct_3layer_MLApplicator( - rng::AbstractRNG, case::HVI.AbstractHybridCase, ; + rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ; scenario::NTuple = ()) `ml_engine` usually is of type `Val{Symbol}`, e.g. Val(:Flux). See `select_ml_engine`. diff --git a/src/init_hybrid_params.jl b/src/init_hybrid_params.jl index 7480399..8d33fbd 100644 --- a/src/init_hybrid_params.jl +++ b/src/init_hybrid_params.jl @@ -12,7 +12,7 @@ Returns a NamedTuple of # Arguments - `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters -- `ϕg`: vector of parameters to optimize, as returned by `get_hybridcase_MLapplicator` +- `ϕg`: vector of parameters to optimize, as returned by `get_hybridproblem_MLapplicator` - `n_batch`: the number of sites to predicted in each mini-batch - `transP`, `transM`: the Bijector.Transformations for the global and site-dependent parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`. diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index 1430cde..380c65a 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -61,12 +61,12 @@ scenario = (:default,) @testset "loss_gf" begin #----------- fit g and θP to y_o rng = StableRNG(111) - g, ϕg0 = get_hybridcase_MLapplicator(prob; scenario) - train_loader = get_hybridcase_train_dataloader(rng, prob; scenario) + g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario) + train_loader = get_hybridproblem_train_dataloader(rng, prob; scenario) (xM, xP, y_o, y_unc) = first(train_loader) - f = get_hybridcase_PBmodel(prob; scenario) - par_templates = get_hybridcase_par_templates(prob; scenario) - (;transM, transP) = get_hybridcase_transforms(prob; scenario) + f = get_hybridproblem_PBmodel(prob; scenario) + par_templates = get_hybridproblem_par_templates(prob; scenario) + (;transM, transP) = get_hybridproblem_transforms(prob; scenario) int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector( ϕg=1:length(ϕg0), θP=par_templates.θP)) @@ -99,20 +99,20 @@ import Flux @testset "neg_elbo_transnorm_gf cpu" begin rng = StableRNG(111) - g, ϕg0 = get_hybridcase_MLapplicator(prob) - train_loader = get_hybridcase_train_dataloader(rng, prob) + g, ϕg0 = get_hybridproblem_MLapplicator(prob) + train_loader = get_hybridproblem_train_dataloader(rng, prob) (xM, xP, y_o, y_unc) = first(train_loader) n_batch = size(y_o, 2) - f = get_hybridcase_PBmodel(prob) - (θP0, θM0) = get_hybridcase_par_templates(prob) - (; transP, transM) = get_hybridcase_transforms(prob) - py = get_hybridcase_neg_logden_obs(prob) + f = get_hybridproblem_PBmodel(prob) + (θP0, θM0) = get_hybridproblem_par_templates(prob) + (; transP, transM) = get_hybridproblem_transforms(prob) + py = get_hybridproblem_neg_logden_obs(prob) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( θP0, θM0, ϕg0, n_batch; transP, transM) ϕ_ini = ϕ - py = get_hybridcase_neg_logden_obs(prob) + py = get_hybridproblem_neg_logden_obs(prob) cost = neg_elbo_transnorm_gf(rng, ϕ_ini, g, transPMs_batch, f, py, xM, xP, y_o, y_unc, map(get_concrete, interpreters); @@ -152,7 +152,7 @@ import Flux n_MC=8), ϕ) @test gr[1] isa CuVector - @test eltype(gr[1]) == get_hybridcase_float_type(prob) + @test eltype(gr[1]) == get_hybridproblem_float_type(prob) end end end diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index 0c49d78..89cdce3 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -11,14 +11,14 @@ import Zygote using OptimizationOptimisers -const case = DoubleMM.DoubleMMCase() +const prob = DoubleMM.DoubleMMCase() scenario = (:default,) -par_templates = get_hybridcase_par_templates(case; scenario) +par_templates = get_hybridproblem_par_templates(prob; scenario) rng = StableRNG(111) (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc -) = gen_hybridcase_synthetic(rng, case; scenario); +) = gen_hybridcase_synthetic(rng, prob; scenario); @testset "gen_hybridcase_synthetic" begin @test isapprox( @@ -28,13 +28,13 @@ rng = StableRNG(111) # test same results for same rng rng2 = StableRNG(111) - gen2 = gen_hybridcase_synthetic(rng2, case; scenario); + gen2 = gen_hybridcase_synthetic(rng2, prob; scenario); @test gen2.y_o == y_o end @testset "loss_g" begin - g, ϕg0 = get_hybridcase_MLapplicator(rng, case; scenario); - (;transP, transM) = get_hybridcase_transforms(case; scenario) + g, ϕg0 = get_hybridproblem_MLapplicator(rng, prob; scenario); + (;transP, transM) = get_hybridproblem_transforms(prob; scenario) function loss_g(ϕg, x, g, transM) # @show first(x,5) @@ -66,9 +66,9 @@ end @testset "loss_gf" begin #----------- fit g and θP to y_o (without uncertainty, without transforming θP) - g, ϕg0 = get_hybridcase_MLapplicator(case; scenario); - (;transP, transM) = get_hybridcase_transforms(case; scenario) - f = get_hybridcase_PBmodel(case; scenario) + g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario); + (;transP, transM) = get_hybridproblem_transforms(prob; scenario) + f = get_hybridproblem_PBmodel(prob; scenario) int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector( ϕg = 1:length(ϕg0), θP = par_templates.θP)) @@ -78,8 +78,8 @@ end # Pass the site-data for the batches as separate vectors wrapped in a tuple n_batch = 10 train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch) - # get_hybridcase_train_dataloader recreates synthetic data different θ_true - #train_loader = get_hybridcase_train_dataloader(case, rng; scenario) + # get_hybridproblem_train_dataloader recreates synthetic data different θ_true + #train_loader = get_hybridproblem_train_dataloader(prob, rng; scenario) loss_gf = get_loss_gf(g, transM, f, y_global_o, int_ϕθP) l1 = loss_gf(p0, first(train_loader)...)[1] diff --git a/test/test_elbo.jl b/test/test_elbo.jl index dc26383..5723007 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -16,25 +16,25 @@ using GPUArraysCore: GPUArraysCore #CUDA.device!(4) rng = StableRNG(111) -const case = DoubleMM.DoubleMMCase() +const prob = DoubleMM.DoubleMMCase() scenario = (:default,) -FT = get_hybridcase_float_type(case; scenario) +FT = get_hybridproblem_float_type(prob; scenario) -#θsite_true = get_hybridcase_par_templates(case; scenario) -g, ϕg0 = get_hybridcase_MLapplicator(case; scenario); -f = get_hybridcase_PBmodel(case; scenario) +#θsite_true = get_hybridproblem_par_templates(prob; scenario) +g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario); +f = get_hybridproblem_PBmodel(prob; scenario) n_covar = 5 n_batch = 10 -n_θM, n_θP = values(map(length, get_hybridcase_par_templates(case; scenario))) +n_θM, n_θP = values(map(length, get_hybridproblem_par_templates(prob; scenario))) (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc -) = gen_hybridcase_synthetic(rng, case; scenario); +) = gen_hybridcase_synthetic(rng, prob; scenario); py = neg_logden_indep_normal n_MC = 3 -(; transP, transM) = get_hybridcase_transforms(case; scenario) +(; transP, transM) = get_hybridproblem_transforms(prob; scenario) # transP = elementwise(exp) # transM = Stacked(elementwise(identity), elementwise(exp)) #transM = Stacked(elementwise(identity), elementwise(exp), elementwise(exp)) # test mismatch @@ -60,7 +60,7 @@ using Flux if CUDA.functional() scenario_flux = (scenario..., :use_Flux) - g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case; scenario = scenario_flux) + g_flux, ϕg0_flux_cpu = get_hybridproblem_MLapplicator(prob; scenario = scenario_flux) end if CUDA.functional() diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl index b776fe8..062706d 100644 --- a/test/test_sample_zeta.jl +++ b/test/test_sample_zeta.jl @@ -15,13 +15,13 @@ using Bijectors #CUDA.device!(4) rng = StableRNG(111) -const case = DoubleMM.DoubleMMCase() +const prob = DoubleMM.DoubleMMCase() scenario = (:default,) -n_θM, n_θP = length.(values(get_hybridcase_par_templates(case; scenario))) +n_θM, n_θP = length.(values(get_hybridproblem_par_templates(prob; scenario))) (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o -) = gen_hybridcase_synthetic(rng, case; scenario) +) = gen_hybridcase_synthetic(rng, prob; scenario) # set to 0.02 rather than zero for debugging non-zero correlations ρsP = zeros(sum(1:(n_θP-1))) .+ 0.02