diff --git a/.gitignore b/.gitignore
index 9429e18..1010547 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,7 @@ tmp/
 **/tmp.svg
 dev/intermediate/*
 dev/tmp.pdf
+docs/src/**/*_files/libs
+docs/src/**/*.html
+docs/src/**/*.ipynb
+docs/src/**/*Manifest.toml
diff --git a/README.md b/README.md
index e336dc4..e3df30f 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# HybridVariationalInference
+# HybridVariationalInference HVI
 
 [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://EarthyScience.github.io/HybridVariationalInference.jl/stable/)
 [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://EarthyScience.github.io/HybridVariationalInference.jl/dev/)
@@ -6,11 +6,71 @@
 [![Coverage](https://codecov.io/gh/EarthyScience/HybridVariationalInference.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/EarthyScience/HybridVariationalInference.jl)
 [![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
 
-Extending Variational Inference (VI), an approximate bayesian inversion method,
-to hybrid models, i.e. models that combine mechanistic and machine-learning parts.
+Estimating uncertainty in hybrid models,
+i.e. models that combine mechanistic and machine-learning parts,
+by extending Variational Inference (VI), an approximate Bayesian inversion method.
+
+## Problem
+
+Consider the case of parameter learning, a special case of hybrid models,
+where a machine learning model, $g_{\phi_g}$, uses known covariates $x_{Mi}$ at site $i$
+to predict a subset of the parameters, $\theta$, of the process-based model, $f$.
+
+The analyst is interested in both
+- the uncertainty of hybrid model predictions, $ŷ$ (predictive posterior), and
+- the uncertainty of process-model parameters $\theta$, including their correlations
+  (posterior).
+
+For example, consider a soil organic matter process-model that predicts carbon stocks for
+different sites. We need to parameterize the unknown carbon use efficiency (CUE) of the soil
+microbial community, which differs by site but is hypothesized to correlate with climate variables
+and pedogenic factors, such as clay content.
+We apply a machine learning model to estimate CUE and fit it end-to-end with other
+parameters of the process-model to observed carbon stocks.
+In addition to the predicted CUE, we are interested in the uncertainty of CUE and its correlation
+with other parameters.
+We are interested in the entire posterior probability distribution of the model parameters.
+
+To understand the background of HVI, refer to the [documentation](https://EarthyScience.github.io/HybridVariationalInference.jl/dev/).
+
+## Usage
+![image info](./docs/src/hybrid_variational_setup.png)
+
+In order to apply HVI, the user has to construct a `HybridProblem` object by specifying
+- the machine learning model, $g$
+- covariates $X_{Mi}$ for each site, $i$
+- the names of parameters that differ across sites, $\theta_M$, and global parameters
+  that are the same across sites, $\theta_P$
+  - optionally, sub-blocks in the within-site correlation structure of the parameters
+  - optionally, which global parameters should be provided to $g$ as additional covariates,
+    to account for correlations between global and site parameters
+- the parameter transformations from unconstrained scale to the scale relevant to the process model, $\theta = T(\zeta)$, e.g. for strictly positive parameters specify `exp`.
+- the process-model, $f$
+- drivers of the process-model $X_{Pi}$ at each site, $i$
+- the likelihood function of the observations, given the model predictions, $p(y|ŷ, \theta)$
+
+Next, this problem is passed to a `HybridPosteriorSolver` that fits an approximation
+of the posterior. It returns a `NamedTuple` of
+- `ϕ`: the fitted parameters, a `ComponentVector` with components
+  - the machine learning model parameters (usually weights), $\phi_g$
+  - means of the global parameters, $\phi_P = \mu_{\zeta_P}$, at transformed
+    unconstrained scale
+  - additional parameters, $\phi_{unc}$, of the posterior, $q(\zeta)$, such as
+    coefficients that describe the scaling of variance with magnitude
+    and coefficients that parameterize the Cholesky factor of the correlation matrix.
+- `θP`: predicted means of the global parameters, $\theta_P$
+- `resopt`: the original result object of the optimizer (useful for debugging)
+
+TODO: document how to get
+- means of the site parameters for each site
+- samples of the posterior
+- samples of the predictive posterior
+
+## Example
+TODO
+
+see test/test_HybridProblem.jl
+
+
+
-The model inversion, infers parametric approximations of posterior density
-of model parameters, by comparing model outputs to uncertain observations. At
-the same time, a machine learning model is fit that predicts parameters of these
-approximations by covariates.
diff --git a/dev/Project.toml b/dev/Project.toml
index 89fb121..2466234 100644
--- a/dev/Project.toml
+++ b/dev/Project.toml
@@ -1,15 +1,21 @@
 [deps]
+AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
+CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
+ColorBrewer = "a2cac450-b92f-5266-8821-25eda20663c8"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+FigureHelpers = "9ae22f58-2487-4805-bfc5-386577db46c8"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 HybridVariationalInference = "a108c475-a4e2-4021-9a84-cfa7df242f64"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl
index 3a9455b..895ea13 100644
--- a/dev/doubleMM.jl
+++ b/dev/doubleMM.jl
@@ -1,4 +1,5 @@
-using Test
+# start from within dev directory mljulia --project
+using Test # Pkg.activate("dev"); cd("dev")
 using HybridVariationalInference
 using HybridVariationalInference: HybridVariationalInference as HVI
 using StableRNGs
@@ -21,7 +22,9 @@
 scenario = Val((:use_Flux, :use_gpu, :omit_r0, :few_sites))
 scenario = Val((:use_Flux, :use_gpu, :omit_r0, :few_sites, :covarK2))
 scenario = Val((:use_Flux, :use_gpu, :omit_r0, :sites20, :covarK2))
 scenario = Val((:use_Flux, :use_gpu, :omit_r0))
-scenario = Val((:use_Flux, :use_gpu, :omit_r0, :covarK2))
+scenario = Val((:use_Flux, :use_gpu, :omit_r0, :covarK2, :neglect_cor,))
+scenario = Val((:use_Flux, :use_gpu, :omit_r0, :covarK2, :K1global,))
+scenario = Val((:use_Flux, :use_gpu, :omit_r0, :covarK2, ))
 #
 prob = DoubleMM.DoubleMMCase()
 gdev = :use_gpu ∈
HVI._val_value(scenario) ? gpu_device() : identity @@ -40,8 +43,9 @@ train_dataloader = MLUtils.DataLoader( batchsize = n_batch, partial = false) σ_o = exp.(y_unc[:, 1] / 2) # assign the train_loader, otherwise it eatch time creates another version of synthetic data -prob0 = HVI.update(prob0_; train_dataloader); +prob0 = HybridProblem(prob0_; train_dataloader) #tmp = HVI.get_hybridproblem_ϕunc(prob0; scenario) +#prob0.covar #------- pointwise hybrid model fit solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01)) @@ -54,7 +58,7 @@ n_epoch = 80 rng, callback = callback_loss(n_batches_in_epoch * 10), maxiters = n_batches_in_epoch * n_epoch); # update the problem with optimized parameters -prob0o = probo; +prob0o = prob1o =probo; y_pred_global, y_pred, θMs = gf(prob0o; scenario, is_inferred=Val(true)); # @descend_code_warntype gf(prob0o; scenario) #@usingany UnicodePlots @@ -72,7 +76,7 @@ histogram(vec(y_pred) - vec(y_true)) # predictions centered around y_o (or y_tru solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site) (; ϕ, resopt) = solve(prob0o, solver1; scenario, rng, callback = callback_loss(20), maxiters = 400) - prob1o = HVI.update(prob0o; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP) + prob1o = HybridProblem(prob0o; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP) y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario) scatterplot(θMs_true[1, :], θMs[1, :]) scatterplot(θMs_true[2, :], θMs[2, :]) @@ -86,7 +90,7 @@ end prob2 = prob1o (; ϕ, resopt) = solve(prob2, solver1; scenario, rng, callback = callback_loss(20), maxiters = 600) - prob2o = HVI.update(prob2; ϕg = collect(ϕ.ϕg), θP = ϕ.θP) + prob2o = HybridProblem(prob2; ϕg = collect(ϕ.ϕg), θP = ϕ.θP) y_pred_global, y_pred, θMs = gf(prob2o, xM, xP) prob2o.θP end @@ -118,11 +122,11 @@ end #scatterplot(θMs_true[1,:], θMs[1,:]) scatterplot(θMs_true[2, :], θMs[2, :]) # able to fit θMs[2,:] - prob3 = HVI.update(prob0, ϕg = Array(ϕg_opt1), θP = θP_true) + prob3 = HybridProblem(prob0, ϕg = Array(ϕg_opt1), θP = θP_true) solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site) (; ϕ, resopt) = solve(prob3, solver1; scenario, rng, callback = callback_loss(50), maxiters = 600) - prob3o = HVI.update(prob3; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP) + prob3o = HybridProblem(prob3; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP) y_pred_global, y_pred, θMs = gf(prob3o, xM, xP; scenario) scatterplot(θMs_true[2, :], θMs[2, :]) prob3o.θP @@ -146,36 +150,40 @@ using MLUtils import Zygote using Bijectors -probh = prob0o # start from point optimized to infer uncertainty -#probh = prob1o # start from point optimized to infer uncertainty -#probh = prob0 # start from no information -solver_post = HybridPosteriorSolver(; - alg = OptimizationOptimisers.Adam(0.01), n_MC = 3) -#solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200) -n_batches_in_epoch = n_site ÷ n_batch -n_epoch = 40 -(; ϕ, θP, resopt, interpreters, probo) = solve(probh, solver_post; scenario, - rng, callback = callback_loss(n_batches_in_epoch * 5), - maxiters = n_batches_in_epoch * n_epoch, - θmean_quant = 0.05); -#probh.get_train_loader(;n_batch = 50, scenario) -# update the problem with optimized parameters, including uncertainty -prob1o = probo; -n_sample_pred = 400 -#(; θ, y) = predict_hvi(rng, prob1o, xM, xP; scenario, n_sample_pred); -(; y, θsP, θsMs) = predict_hvi(rng, prob1o; scenario, n_sample_pred, is_inferred=Val(true)); -(y1, θsP1, θsMs1) = (y, θsP, θsMs); - -() -> begin # prediction with fitted parameters (should be smaller than mean) - y_pred_global, y_pred2, θMs 
= gf(prob1o, xM, xP; scenario) - scatterplot(θMs_true[1, :], θMs[1, :]) - scatterplot(θMs_true[2, :], θMs[2, :]) - hcat(θP_true, θP) # all parameters overestimated - histogram(vec(y_pred2) - vec(y_true)) # predicts an unsymmytric distribution +solver_post = HybridPosteriorSolver(; alg = OptimizationOptimisers.Adam(0.01), n_MC = 3) + +() -> begin # priors on mean θ + get_hybridproblem_cor_ends(prob0o) + probh = prob0o # start from point optimized to infer uncertainty + #probh = prob1o # start from point optimized to infer uncertainty + #probh = prob0 # start from no information + #solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200) + n_batches_in_epoch = n_site ÷ n_batch + n_epoch = 40 + (; ϕ, θP, resopt, interpreters, probo) = solve(probh, solver_post; scenario, + rng, callback = callback_loss(n_batches_in_epoch * 5), + maxiters = n_batches_in_epoch * n_epoch, + θmean_quant = 0.05); + #probh.get_train_loader(;n_batch = 50, scenario) + # update the problem with optimized parameters, including uncertainty + prob1o = probo; + n_sample_pred = 400 + #(; θ, y) = predict_hvi(rng, prob1o, xM, xP; scenario, n_sample_pred); + (; y, θsP, θsMs) = predict_hvi(rng, prob1o; scenario, n_sample_pred, is_inferred=Val(true)); + (y1, θsP1, θsMs1) = (y, θsP, θsMs); + + () -> begin # prediction with fitted parameters (should be smaller than mean) + y_pred_global, y_pred2, θMs = gf(prob1o, xM, xP; scenario) + scatterplot(θMs_true[1, :], θMs[1, :]) + scatterplot(θMs_true[2, :], θMs[2, :]) + hcat(θP_true, θP) # all parameters overestimated + histogram(vec(y_pred2) - vec(y_true)) # predicts an unsymmytric distribution + end end -#----------- continue HVI without strong prior on θmean -prob2 = HVI.update(prob1o); # copy +#----------- HVI without strong prior on θmean +#prob2 = HybridProblem(prob1o); # copy +prob2 = HybridProblem(prob0o); # copy function fstate_ϕunc(state) u = state.u |> cpu #Main.@infiltrate_main @@ -184,31 +192,60 @@ function fstate_ϕunc(state) end n_epoch = 100 #n_epoch = 400 +#n_epoch = 2 (; ϕ, θP, resopt, interpreters, probo) = solve(prob2, - HVI.update(solver_post, n_MC = 12); - #HVI.update(solver_post, n_MC = 30); + HybridProblem(solver_post, n_MC = 12); + #HybridProblem(solver_post, n_MC = 30); scenario, rng, maxiters = n_batches_in_epoch * n_epoch, - callback = HVI.callback_loss_fstate(n_batches_in_epoch*5, fstate_ϕunc)); + #callback = HVI.callback_loss_fstate(n_batches_in_epoch*5, fstate_ϕunc), + callback = callback_loss(n_batches_in_epoch * 5), + ); prob2o = probo; -() -> begin +() -> begin # store and reload optimized problem using JLD2 #fname_probos = "intermediate/probos_$(last(scenario)).jld2" fname_probos = "intermediate/probos800_$(last(HVI._val_value(scenario))).jld2" - JLD2.save(fname_probos, Dict("prob1o" => prob1o, "prob2o" => prob2o)) + @show fname_probos + #JLD2.save(fname_probos, Dict("prob1o" => prob1o, "prob2o" => prob2o)) + JLD2.save(fname_probos, Dict("prob2o" => prob2o)) tmp = JLD2.load(fname_probos) + prob2o = probo = tmp["prob2o"] end -() -> begin # load the non-covar scenario +() -> begin # load the non-covar scenario, and neglect_cor scenario using JLD2 - #fname_probos = "intermediate/probos_$(last(_val_value(scenario))).jld2" - fname_probos = "intermediate/probos800_omit_r0.jld2" - tmp = JLD2.load(fname_probos) + scenario_indep = Val(Tuple(s for s in HVI._val_value(scenario) if s != :covarK2)) + fname_probos_indep = "intermediate/probos800_$(last(HVI._val_value(scenario_indep))).jld2" + #fname_probos = "intermediate/probos800_omit_r0.jld2" + tmp = 
JLD2.load(fname_probos_indep) + prob2o_indep = tmp["prob2o"] # test predicting correct obs-uncertainty of predictive posterior n_sample_pred = 400 - (; θ, y, entropy_ζ) = predict_hvi(rng, prob2o_indep, xM, xP; scenario, n_sample_pred); - (θ2_indep, y2_indep) = (θ, y) + (; y, θsP, θsMs) = predict_hvi(rng, prob2o_indep; scenario = scenario_indep, n_sample_pred); + (y2_indep, θsP2_indep, θsMs2_indep) = (y, θsP, θsMs); + #θsMs2_indep .- θsMs2 #(θ2_indep, y2_indep) = (θ2, y2) # workaround to use covarK2 when loading failed + # + scenario_neglect_cor = Val((HVI._val_value(scenario)..., :neglect_cor)) + fname_probos_neglect_cor = "intermediate/probos800_$(last(HVI._val_value(scenario_neglect_cor))).jld2" + #fname_probos = "intermediate/probos800_omit_r0.jld2" + tmp = JLD2.load(fname_probos_neglect_cor) + prob2o_neglect_cor = tmp["prob2o"] + # test predicting correct obs-uncertainty of predictive posterior + n_sample_pred = 400 + (; y, θsP, θsMs) = predict_hvi(rng, prob2o_neglect_cor; scenario = scenario_neglect_cor, n_sample_pred); + (y2_neglect_cor, θsP2_neglect_cor, θsMs2_neglect_cor) = (y, θsP, θsMs); + # + scenario_K1global = Val((HVI._val_value(scenario)..., :K1global)) + fname_probos_K1global = "intermediate/probos800_$(last(HVI._val_value(scenario_K1global))).jld2" + #fname_probos = "intermediate/probos800_omit_r0.jld2" + tmp = JLD2.load(fname_probos_K1global) + prob2o_K1global = tmp["prob2o"] + # test predicting correct obs-uncertainty of predictive posterior + n_sample_pred = 400 + (; y, θsP, θsMs) = predict_hvi(rng, prob2o_K1global; scenario = scenario_K1global, n_sample_pred); + (y2_K1global, θsP2_K1global, θsMs2_K1global) = (y, θsP, θsMs); end () -> begin # otpimize using LUX @@ -286,10 +323,10 @@ histogram(θsP) # (; ϕ, θP, resopt, interpreters) = solve(prob1o, solver_MC; scenario, # rng, callback = callback_loss(n_batches_in_epoch), maxiters = 14); # resopt.objective - # probo = prob3o = HVI.update(prob2; ϕg = cpu_ca(ϕ).ϕg, θP = θP, ϕunc = cpu_ca(ϕ).unc) + # probo = prob3o = HybridProblem(prob2; ϕg = cpu_ca(ϕ).ϕg, θP = θP, ϕunc = cpu_ca(ϕ).unc) - solver_post2 = HVI.update(solver_post; n_MC = 30) - #solver_post2 = HVI.update(solver_post; n_MC=3) + solver_post2 = HybridPosteriorSolver(solver_post; n_MC = 30) + #solver_post2 = HybridPosteriorSolver(solver_post; n_MC = 3) n_rep = 30 n_batchf = n_site n_batchf = n_site ÷ 10 @@ -316,6 +353,7 @@ histogram(θsP) end () -> begin # look at distribution of parameters, predictions, and likelihood and elob at one site + # compare prob1o (with constraining theta to be near original mean) to unconstrained HVI function predict_site(probo, i_site) (; y, θsP, θsMs, entropy_ζ) = predict_hvi(rng, probo; scenario, n_sample_pred) y_site = y[:, i_site, :] @@ -337,7 +375,7 @@ end end i_site = 1 (r1s, nLs, ent, y_site) = predict_site(prob2o, i_site) - (r1sc, nLsc, entc, y_sitec) = predict_site(prob1o, i_site) + (r1sc, nLsc, entc, y_sitec) = predict_site(prob1o, i_site) # result from point-solver mean(nLs), mean(nLsc) ent, entc # with larger uncertaintsy (higher entropy) in unconstrained cost much lower @@ -418,7 +456,7 @@ using MCMCChains prior_ζ = fit(Normal, @qp_ll(log(1e-2)), @qp_uu(log(10))) prior_ζn = (n) -> MvNormal(fill(prior_ζ.μ, n), PDiagMat(fill(abs2(prior_ζ.σ), n))) prior_ζn(3) -prob = HVI.update(prob0o); +prob = HybridProblem(prob0o); (; θM, θP) = get_hybridproblem_par_templates(prob; scenario) n_θM, n_θP = length.((θM, θP)) @@ -447,7 +485,7 @@ f = get_hybridproblem_PBmodel(prob; scenario) #TODO specify with transPM # @show ζP # 
Main.@infiltrate_main # step to second time - y_pred = f(exp.(ζP), exp.(ζMs), xP)[2] # first is global + y_pred = f(exp.(ζP), exp.(ζMs)', xP)[2] # first is global for i_obs in 1:n_obs y[i_obs, :] ~ MvNormal(y_pred[i_obs, :], σ_o[i_obs]) # single value σ instead of variance end @@ -457,7 +495,7 @@ f = get_hybridproblem_PBmodel(prob; scenario) y_pred end -model = fsites(y_o; f, n_θP, n_θM, σ_o) +model = fsites(y_o; f = prob0.f_allsites, n_θP, n_θM, σ_o) # setup transformers and interpreters for forward prediction cor_ends = get_hybridproblem_cor_ends(prob; scenario) @@ -465,7 +503,7 @@ g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario) ϕunc0 = get_hybridproblem_ϕunc(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) hpints = HybridProblemInterpreters(prob; scenario) -(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( +(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = HVI.init_hybrid_params( θP, θM, cor_ends, ϕg0, hpints; transP, transM, ϕunc0); intm_PMs_gen = get_int_PMs_site(hpints); @@ -496,6 +534,7 @@ end # takes ~ 25 minutes #n_sample_NUTS = 800 n_sample_NUTS = 2000 +#tmp = sample(model, NUTS(0,0.65), 2, initial_params = ζ0_true .+ 0.001) #chain = sample(model, NUTS(), n_sample_NUTS, initial_params = ζ0_true .+ 0.001) #n_sample_NUTS = 24 n_threads = 8 @@ -511,6 +550,15 @@ chain = sample(model, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_thread n_sample_NUTS = size(Array(chain),1) end +() -> begin # load HMC sample for K1global scenario + using JLD2 + fname_K1global = "intermediate/doubleMM_chain_zeta_K1global.jld2" + chain_K1global = load(fname_K1global, "chain"; iotype = IOStream); + ζsP_hmc_K1global = Array(chain_K1global)[:,1:n_θP]' + ζsMst_hmc_K1global = reshape(Array(chain_K1global)[:,(n_θP+1) : end], n_sample_NUTS, n_site, n_θM) + ζsMs_hmc_K1global = permutedims(ζsMst_hmc_K1global, (2,3,1)) +end + #ζi = first(eachrow(Array(chain))) f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true) #ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true), hcat, eachrow(Array(chain))); @@ -521,7 +569,11 @@ f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true) ζsMs[:,:,1] # first sample: n_site x n_par ζsMs[:,1,:] # first parameter n_site x n_sample -(; y, θsP, θsMs) = HVI.apply_f_trans(ζsP, ζsMs, f_allsites, xP; transP, transM); +trans_mP=StackedArray(transP, size(ζsP, 2)) +trans_mMs=StackedArray(transM, size(ζsMs, 1) * size(ζsMs, 3)) +θsP, θsMs = transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs) +y = apply_process_model(θsP, θsMs, f, xP) +#(; y, θsP, θsMs) = HVI.apply_f_trans(ζsP, ζsMs, f_allsites, xP; transP, transM); (y_hmc, θsP_hmc, θsMs_hmc) = (; y, θsP, θsMs); @@ -537,14 +589,13 @@ f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true) end () -> begin # plot chain - #@usingany TwMakieHelpers, CairoMakie + #@usingany FigureHelpers, CairoMakie # θP and first θMs ch = chain[:,vcat(1:n_θP, n_θP+1, n_θP+n_site+1),:]; fig = plot_chn(ch) save("tmp.svg", fig) end - mean_y_invζ = mean_y_hmc = map(mean, eachslice(y_hmc; dims = (1, 2))); #describe(mean_y_pred - y_o) histogram(vec(mean_y_invζ) - vec(y_true)) # predictions centered around y_o (or y_true) @@ -570,9 +621,27 @@ plt = scatterplot(θMs_true'[:,2], mean_θMs[:,2]); lineplot!(plt, 0, 1) #------------------ compare HVI vs HMC sample +# reload results from run without covars, see above () -> begin # compare against HVI sample - #@usingany AlgebraOfGraphics, TwPrototypes, CairoMakie, 
DataFrames - makie_config = ppt_MakieConfig() + #@usingany AlgebraOfGraphics, FigureHelpers, TwPrototypes, CairoMakie, DataFrames + #@usingany LaTeXStrings + using AlgebraOfGraphics, CairoMakie, FigureHelpers + using AlgebraOfGraphics, CairoMakie, FigureHelpers + makie_config = ppt_MakieConfig(fontsize=14) # decrease standard font from 18 to 14 + paper_config = paper_MakieConfig() + set_default_AoGTheme!(;makie_config) + + using ColorBrewer: ColorBrewer + # two same colors for hmc and hvi , additional for further unspecified labels + cDark2 = cgrad(ColorBrewer.palette("Dark2",3),3,categorical=true) + #color_methods = vcat([k => col for (k, col) in zip([:hmc, :hvi], cDark2[1:2])], cDark2[3], Makie.wong_colors()[2:end]); + cpal0 = Makie.wong_colors() + color_methods = vcat([k => col for (k, col) in zip([:hmc, :hvi], cpal0[1:2])], cpal0[3:end]); + + function lower_lastdigits(sym::Symbol,n_digit=1) + s = string(sym) + latexstring(s[1:(end-n_digit)] * "_" * s[(end-n_digit+1):end]) + end function get_fig_size(cfg; width2height=golden_ratio, xfac=1.0) cfg = makie_config x_inch = first(cfg.size_inches) * xfac @@ -581,23 +650,35 @@ lineplot!(plt, 0, 1) end ζsP_hvi = log.(θsP2) - ζsP_hvi_indep = log.(θsP2) # TODO rerun and reload replace θsP2 + ζsP_hvi_indep = log.(θsP2_indep) + ζsP_hvi_neglect_cor = log.(θsP2_neglect_cor) + ζsP_hvi_K1global = log.(θsP2_K1global) ζsP_hmc = log.(θsP_hmc) ζsMs_hvi = log.(θsMs2) - ζsMs_hvi_indep = log.(θsMs2) # TODO rerun and reload replace θsMs2 + ζsMs_hvi_indep = log.(θsMs2_indep) + ζsMs_hvi_neglect_cor = log.(θsMs2_neglect_cor) + ζsMs_hvi_K1global = log.(θsMs2_K1global) ζsMs_hmc = log.(θsMs_hmc) # int_pms = interpreters.PMs # par_pos = int_pms(1:length(int_pms)) - i_sites = 1:10 + #i_sites = 1:10 + i_sites = 1:5 #i_sites = 6:10 #i_sites = 11:15 - scen = vcat(fill(:hvi,size(ζsP_hvi,2)),fill(:hmc,size(ζsP_hmc,2)),fill(:hvi_indep,size(ζsP_hvi_indep,2))) + scen = vcat( + fill(:hvi,size(ζsP_hvi,2)), + fill(:hmc,size(ζsP_hmc,2)), + fill(:hvi_indep,size(ζsP_hvi_indep,2)), + fill(:neglect_cor,size(ζsP_hvi_neglect_cor,2)), + ) dfP = mapreduce(vcat, axes(θP_true,1)) do i_par #pos = par_pos.P[i_par] DataFrame( - value = vcat(ζsP_hvi[i_par, :], ζsP_hmc[i_par,:], ζsP_hvi_indep[i_par,:]), - variable = keys(θP_true)[i_par], - site = i_sites[1], + value = vcat( + ζsP_hvi[i_par, :], ζsP_hmc[i_par,:], + ζsP_hvi_indep[i_par,:], ζsP_hvi_neglect_cor[i_par,:]), + variable = lower_lastdigits.(keys(θP_true)[i_par]), + site = "site $(i_sites[1])", Method = scen ) end @@ -608,9 +689,10 @@ lineplot!(plt, 0, 1) value = vcat( ζsMs_hvi[i_site,i_par,:], ζsMs_hmc[i_site,i_par,:], - ζsMs_hvi_indep[i_site,i_par,:]), - variable = keys(θM)[i_par], - site = i_site, + ζsMs_hvi_indep[i_site,i_par,:], + ζsMs_hvi_neglect_cor[i_site,i_par,:],), + variable = lower_lastdigits.(keys(θM)[i_par]), + site = "site $(i_site)", Method = scen ) end @@ -620,8 +702,8 @@ lineplot!(plt, 0, 1) mapreduce(vcat, axes(θP,1)) do i_par DataFrame( value = ζP_true[i_par], - variable = keys(θP)[i_par], - site = i_sites[1], + variable = lower_lastdigits.(keys(θP)[i_par]), + site = "site $(i_sites[1])", Method = :true ) end, @@ -629,29 +711,43 @@ lineplot!(plt, 0, 1) mapreduce(vcat, axes(θM,1)) do i_par DataFrame( value = ζMs_true[i_par, i_site], - variable = keys(θM)[i_par], - site = i_site, + variable = lower_lastdigits.(keys(θM)[i_par]), + site = "site $(i_site)", Method = :true ) end end ) # vcat #cf90 = (x) -> quantile(x, [0.05,0.95]) - plt = (data(subset(df, :Method => ByRow(∈((:hvi,:hmc))))) * mapping(:value=> (x -> x ) => 
"", color=:Method) * AlgebraOfGraphics.density(datalimits=extrema) + - data(df_true) * mapping(:value) * visual(VLines; color=:blue, linestyle=:dash)) * - mapping(col=:variable => sorter(vcat(keys(θP)..., keys(θM)...)), row = (:site => nonnumeric)) - fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2)); - fg = draw!(fig, plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), axis=(xlabelvisible=false,)); - fig - save("tmp.svg", fig) + plot_par_densities = (dfs; makie_config = makie_config) -> begin + plt = (data(dfs) * mapping(:value=> (x -> x ) => "", color=:Method) * AlgebraOfGraphics.density(datalimits=extrema) + + data(df_true) * mapping(:value => "") * visual(VLines; color=:blue, linestyle=:dash)) * + mapping(col=:variable => sorter(lower_lastdigits.(vcat(keys(θP)..., keys(θM)...))), + row = (:site => nonnumeric)) + #mapping(col=:variable, row = (:site => nonnumeric)) + #fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2)); # 10 sites + fig = figure_conf(1.0; makie_config); + ffig = draw!(fig, plt, + facet=(; linkxaxes=:minimal, linkyaxes=:none,), + axis=(xlabelvisible=false,yticklabelsvisible=false), + scales(Color = (; palette = color_methods)), + ); + legend!(fig[length(i_sites),1], ffig, ; tellwidth=false, halign=:left, valign=:bottom , margin=(10, 10, 10, 10)) + fig + end + () -> begin + save_with_config(joinpath(pwd(), "tmp.svg"), fig; makie_config = MakieConfig()) + save_with_config(joinpath(pwd(), "tmp"), fig; makie_config = paper_config) + save_with_config("tmp", fig) # returns path in tmp to click on + end + fig = plot_par_densities(subset(df, :Method => ByRow(∈((:hvi,:hmc))))) save_with_config("intermediate/compare_hmc_hvi_sites_$(last(HVI._val_value(scenario)))", fig; makie_config) - plt = (data(subset(df, :Method => ByRow(∈((:hvi, :hvi_indep))))) * mapping(:value=> (x -> x ) => "", color=:Method) * AlgebraOfGraphics.density(datalimits=extrema) + - data(df_true) * mapping(:value) * visual(VLines; color=:blue, linestyle=:dash)) * - mapping(col=:variable => sorter(vcat(keys(θP)..., keys(θM)...)), row = (:site => nonnumeric)) - fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2)); - fg = draw!(fig, plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), axis=(xlabelvisible=false,)); - fig + fig = plot_par_densities(subset(df, :Method => ByRow(∈((:hmc,:neglect_cor))))) + save("tmp.svg", fig) + save_with_config("intermediate/compare_hmc_neglectcor_sites_$(last(HVI._val_value(scenario)))", fig; makie_config) + + fig = plot_par_densities(subset(df, :Method => ByRow(∈((:hvi,:hvi_indep))))) save("tmp.svg", fig) save_with_config("intermediate/compare_hvi_indep_sites_$(last(HVI._val_value(scenario)))", fig; makie_config) @@ -665,17 +761,19 @@ lineplot!(plt, 0, 1) vcat( DataFrame( value = y_hmc[i_obs,i_site,:], - site = i_site, + site = "site $(i_site)", Method = :hmc, variable = :y, i_obs = i_obs, + y_i = latexstring("y_$(i_obs)"), ), DataFrame( value = y_hvi[i_obs,i_site,:], - site = i_site, + site = "site $(i_site)", Method = :hvi, variable = :y, i_obs = i_obs, + y_i = latexstring("y_$(i_obs)"), ) )# vcat end @@ -685,54 +783,158 @@ lineplot!(plt, 0, 1) vcat( DataFrame( value = y_true[i_obs,i_site], - site = i_site, - Method = :truth, + site = "site $(i_site)", + Reference = :truth, variable = :y, i_obs = i_obs, + y_i = latexstring("y_$(i_obs)"), ), DataFrame( value = y_o[i_obs,i_site,:], - site = i_site, - Method = :obs, + site = "site $(i_site)", + Reference = :obs, variable = :y, i_obs = i_obs, + y_i = 
latexstring("y_$(i_obs)"), ) )# vcat end end + using CategoricalArrays + DataFrames.transform!(dfyt, :Reference => (x -> categorical(string.(x); ordered = true, levels = ["truth", "obs"])) => :Reference) # plt = (data(dfy) * mapping(color=:Method) * AlgebraOfGraphics.density(datalimits=extrema) + - data(dfyt) * mapping(color=:Method) * visual(VLines; linestyle=:dash)) * - #data(dfyt) * mapping(color=:Method, linestyle=:Method) * visual(VLines; linestyle=:dash)) * - mapping(:value=>"", col=:i_obs => nonnumeric, row = :site => nonnumeric) - - fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2)); - f = draw!(fig, plt, + data(dfyt) * mapping(color=:Reference => AlgebraOfGraphics.scale(:Reference)) * visual(VLines; linestyle=:dash)) * + #data(dfyt) * mapping(linestyle=:Reference => AlgebraOfGraphics.scale(:Reference)) * visual(VLines; linestyle=:dash)) * + #data(dfyt) * mapping(color=:Reference => AlgebraOfGraphics.scale(:Reference), + # linestyle= :Reference => AlgebraOfGraphics.scale(:Reference)) * visual(VLines)) * # bug? + mapping(:value=>"", col=:y_i, row = :site) + + #fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2)); + fig = figure_conf(1; makie_config); + ffig = draw!(fig, plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), - axis=(xlabelvisible=false,yticklabelsvisible=false)); - legend!(fig[1,3], f, ; tellwidth=false, halign=:right, valign=:top) # , margin=(-10, -10, 10, 10) + axis=(xlabelvisible=false,yticklabelsvisible=false), + scales(Color = (; palette = color_methods)), + ); + #legend!(fig[1,3], f, ; tellwidth=false, halign=:right, valign=:top) # , margin=(-10, -10, 10, 10) + legend!(fig[1,4], ffig, ; tellwidth=true, halign=:right, valign=:top) # , margin=(-10, -10, 10, 10) fig save("tmp.svg", fig) - save_with_config("intermediate/compare_hmc_hvi_sites_y_$(last(scenario))", fig; makie_config) + save_with_config("intermediate/compare_hmc_hvi_sites_y_$(last(HVI._val_value(scenario)))", fig; makie_config) # hvi predicts y better, hmc fails for quite a few obs: 3,5,6 - # todo compare mean_predictions + # compare mean_predictions mean_y_hvi = map(mean, eachslice(y_hvi; dims = (1, 2))); + size(y_o) + histogram(vec(mean_y_hvi .- y_o)) +end - +#------------------------------------- correlations ----------------- +""" +Compute standard deviation and correlation for predicted parameters on unconstrained scale. + +## Arguments +_ζsP: n_P x n_pred matrix of draws of predicted cross-sites parameters +_ζsMs: n_site x n_M x n_pred of draws of predicted physical parameters + +returns sdP (n_P), sdMs (n_site x n_M), cor_PMs n_P + (n_M * length(i_sites)) square matrix +""" +function compute_sd_cor_PMs(_ζsP, _ζsMs; i_sites_inspect = [1,2,3]) + mP = mean(_ζsP; dims=2) + residP = _ζsP .- mP + sdP = vec(std(residP; dims=2)) + mMs = mean(_ζsMs; dims=3)[:,:,1] + residMs = _ζsMs .- mMs + sdMs = std(residMs; dims=3)[:,:,1] + residMst = permutedims(residMs[i_sites_inspect,:,:], (2,1,3)) # n_M x n_site x n_pred + residPMst = vcat(residP, + reshape(residMst, size(residMst,1)*size(residMst,2), size(residMst,3))) # n_P x n_pred + corPMs = cor(residPMst') + sdP, sdMs, corPMs +end + +function draw_cor_fig(cor, method; makie_config, par_names) + fig = figure_conf(1.3, 0.8; makie_config); + ax = Axis(fig[1,1], + xticklabelsvisible=false,yticklabelsvisible=true, + xticksvisible=false, yticksvisible=true, + yticks = (axes(par_names,1), par_names), + yreversed = true, + aspect = 1, + title = "Corr. 
$method") + hm = heatmap!(ax, cor) + rowsize!(fig.layout, 1, Aspect(1, 1)) + Colorbar(fig[1,2], hm ; tellwidth=true, tellheight=false) + fig end () -> begin # inspect correlation of residuals - mean_ζ_hvi = map(mean, eachrow(CA.getdata(ζs_hvi))) - r_hvi = ζs_hvi .- mean_ζ_hvi - cor_hvi = cor(CA.getdata(r_hvi)') - mean_ζ_hmc = map(mean, eachrow(CA.getdata(ζs_hmc))) - r_hmc = ζs_hmc .- mean_ζ_hmc - cor_hmc = cor(CA.getdata(r_hmc)') + # get true + i_sites_inspect = [1,2,3] + sdP_hmc, sdMs_hmc, corPMs_hmc = compute_sd_cor_PMs(ζsP_hmc, ζsMs_hmc; i_sites_inspect) + sdP_hvi, sdMs_hvi, corPMs_hvi = compute_sd_cor_PMs(ζsP_hvi, ζsMs_hvi; i_sites_inspect) + sdP_hvi_indep, sdMs_hvi_indep, corPMs_hvi_indep = compute_sd_cor_PMs( + ζsP_hvi_indep, ζsMs_hvi_indep; i_sites_inspect) + sdP_hvi_neglect_cor, sdMs_hvi_neglect_cor, corPMs_hvi_neglect_cor = compute_sd_cor_PMs( + ζsP_hvi_neglect_cor, ζsMs_hvi_neglect_cor; i_sites_inspect) + sdP_hvi_K1global, sdMs_hvi_K1global, corPMs_hvi_K1global = compute_sd_cor_PMs( + ζsP_hvi_K1global, ζsMs_hvi_K1global; i_sites_inspect) + sdP_hmc_K1global, sdMs_hmc_K1global, corPMs_hmc_K1global = compute_sd_cor_PMs( + ζsP_hmc_K1global, ζsMs_hmc_K1global; i_sites_inspect) + # no correlations of K2(global) ML parameters in inversion? + par_names = vcat(["global $k" for k in keys(θP)], vec(["site $i $k" for k in keys(θM), i in i_sites_inspect])) + par_names_globalK1 = vcat(["global K1" for k in keys(θP)], vec(["site $i $k" for k in (:r1, :K2), i in i_sites_inspect])) + fig = draw_cor_fig(corPMs_hmc, "Hamiltonian Monte Carlo posterior"; makie_config, par_names) + save_with_config("intermediate/cor_hmc_$(last(HVI._val_value(scenario)))", fig; makie_config) + fig = draw_cor_fig(corPMs_hvi, "Hybrid Variational Inference posterior"; makie_config, par_names) + save_with_config("intermediate/cor_hvi_$(last(HVI._val_value(scenario)))", fig; makie_config) + fig = draw_cor_fig(corPMs_hvi_neglect_cor, "HVI neglecting block correlations"; makie_config, par_names) + save_with_config("intermediate/cor_hvi_neglect_cor_$(last(HVI._val_value(scenario)))", fig; makie_config) + save_with_config("tmp", fig; makie_config) # - hcat(cor_hvi[:,1], cor_hmc[:,1]) - # positive correlations of K2(1) in θP with K1(3) in θMs + fig = draw_cor_fig(corPMs_hvi_K1global, "HVI K1 global K2 site-dependent"; makie_config, par_names = par_names_globalK1) + save_with_config("intermediate/cor_hvi_K1global_$(last(HVI._val_value(scenario)))", fig; makie_config) + fig = draw_cor_fig(corPMs_hmc_K1global, "HMC K1 global K2 site-dependent"; makie_config, par_names = par_names_globalK1) + save_with_config("intermediate/cor_hmc_K1global_$(last(HVI._val_value(scenario)))", fig; makie_config) + + + df_sd = reduce(vcat, map(axes(θM,1)) do i_par + vcat(DataFrame( + Method = :hmc, + par = lower_lastdigits.(keys(θM)[i_par]), + sd = sdMs_hmc[:,i_par], + value = ζMs_true'[:,i_par], + ), DataFrame( + Method = :hvi, + par = lower_lastdigits.(keys(θM)[i_par]), + sd = sdMs_hvi[:,i_par], + value = ζMs_true'[:,i_par], + ), DataFrame( + Method = :neglect_cor, + par = lower_lastdigits.(keys(θM)[i_par]), + sd = sdMs_hvi_neglect_cor[:,i_par], + value = ζMs_true'[:,i_par], + ),) + end) + plt = data(df_sd) * mapping(color=:Method=>"", row=:par) * + mapping(:value=>"", :sd => "Predicted Standard Deviation") * visual(Scatter, alpha = 0.5) + + #fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2)); + fig = figure_conf(1;makie_config); + #fig = figure_conf(1;makie_config = paper_config); + ffig = draw!(fig, plt, scales(Color = (; 
palette = color_methods));
+        facet=(; linkxaxes=:none, linkyaxes=:none,),
+        )
+    #legend!(fig[1,1], ffig, ;tellheight=false, tellwidth=false, halign=:right, valign=:top, margin=(10, 10, 10, 10))
+    legend!(fig[1,1], ffig, ;tellheight=false, tellwidth=false, halign=:left, valign=:bottom, margin=(10, 10, 10, 10))
+    save_with_config("tmp", fig; makie_config)
+    #save_with_config("tmp", fig; makie_config=paper_config)
+end
+() -> begin # inspect marginal variance
+
+
 end
@@ -744,7 +946,7 @@ end
     prior_θ = Normal(0, 10)
     prior_θn = (n) -> MvNormal(fill(prior_θ.μ, n), PDiagMat(fill(abs2(prior_θ.σ), n)))
     prior_θn(3)
-    prob = HVI.update(prob0o);
+    prob = HybridProblem(prob0o);
     (; θM, θP) = get_hybridproblem_par_templates(prob; scenario)
     n_θM, n_θP = length.((θM, θP))
diff --git a/docs/make.jl b/docs/make.jl
index 46caa1c..3a06459 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,9 +1,11 @@
 using HybridVariationalInference
+import HybridVariationalInference.DoubleMM
 using Documenter
 
 DocMeta.setdocmeta!(HybridVariationalInference, :DocTestSetup, :(using HybridVariationalInference); recursive=true)
 
 makedocs(;
+    #modules=[HybridVariationalInference, HybridVariationalInference.DoubleMM],
     modules=[HybridVariationalInference],
     authors="Thomas Wutzler and contributors",
     sitename="HybridVariationalInference.jl",
@@ -14,6 +16,23 @@ makedocs(;
     ),
     pages=[
         "Home" => "index.md",
+        "Problem" => "problem.md",
+        "Tutorials" => [
+            "Basic workflow" => "tutorials/basic_cpu.md",
+            "Inspect results" => "tutorials/inspect_results.md",
+            #"Test quarto markdown" => "tutorials/test1.md",
+        ],
+        "How to" => [
+            #".. model independent parameters" => "tutorials/how_to_guides/blocks_corr_site.md",
+            #".. model site-global corr" => "tutorials/how_to_guides/corr_site_global.md",
+        ],
+        "Explanation" => [
+            #"Theory" => "explanation/theory_hvi.md", TODO activate when paper is published
+        ],
+        "Reference" => [
+            "Public" => "reference/reference_public.md",
+            "Internal" => "reference/reference_internal.md",
+        ],
     ],
 )
diff --git a/docs/src/explanation/hybrid_variational_setup.png b/docs/src/explanation/hybrid_variational_setup.png
new file mode 100644
index 0000000..47de951
Binary files /dev/null and b/docs/src/explanation/hybrid_variational_setup.png differ
diff --git a/docs/src/explanation/theory_hvi.md b/docs/src/explanation/theory_hvi.md
new file mode 100644
index 0000000..18053a7
--- /dev/null
+++ b/docs/src/explanation/theory_hvi.md
@@ -0,0 +1,172 @@
+# Theory
+
+## Setup of the Problem
+Hybrid variational inference, HVI, infers a parametric approximation of
+the posterior density, $q(\theta|y)$,
+by comparing model outputs to uncertain observations, $y$.
+At the same time, the machine learning model, $g$, is fitted,
+which predicts a subset, $\phi_M$, of the parameters of the approximation $q$.
+In the case where it predicts the marginal means of the process-model's
+parameters, $\theta$, this corresponds to the same machine learning model, $g$,
+as in parameter learning without consideration of uncertainty.
+
+HVI approximates the
+posterior distribution of process-model parameters
+as a transformation of a multivariate normally distributed variable $\zeta$:
+$q(\theta) = T(q(\zeta))$, $q(\zeta) \sim \mathcal{N}(\mu, \Sigma)$.
+This allows for efficient sampling, ensures a finite nonzero probability across
+the entire real space, and, at the same time, provides sufficient flexibility.
+
+![image info](hybrid_variational_setup.png)
+
+The optimized parameters, $\phi = (\phi_g, \phi_P, \phi_u)$, are the same for each site.
+This makes it possible to use minibatching, which does not require predicting the
+full observation vector, $y$, during parameter fitting.
+
+In order to learn $\phi_g$, the user needs to provide a batch of $i \in \{1 \ldots n_{b}\}$ observations $y_i$, their uncertainty, $y_{unc,i}$, covariates $x_{Mi}$, and drivers $x_{Pi}$ in each iteration of the optimization. Moreover, for each $i$, HVI needs to draw $n_{MC}$ parameter samples $\zeta_i$, transform them, run the model to compute a prediction for $y_i$, and compute $\log p(y_i)$ to estimate the expected value occurring in the ELBO (see next section).
+
+## Estimation using the ELBO
+
+In order to find the parameters of the approximation of the posterior, HVI
+minimizes the KL divergence between the approximation and the true posterior.
+This is achieved by maximizing the evidence lower bound (ELBO).
+
+$$\mathcal{L}(\phi) = \mathbb{E}_{q(\theta)} \left[\log p(y,\theta) \right] - \mathbb{E}_{q(\theta)} \left[\log q(\theta) \right]$$
+
+The second term is the entropy of the approximating distribution, which has a closed form
+for a multivariate normal distribution.
+The expectation of the first term can be estimated using Monte-Carlo integration.
+When combined with stochastic gradient descent, this needs only a small number of samples.
+However, HVI needs to compute the gradient of this expectation of the joint
+density of observations and parameters,
+$\log p(y,\theta) = \log p(y|\theta) + \log p(\theta)$,
+by automatic differentiation. Hence, HVI needs to differentiate the process-model, $f$,
+that is run during computation of the likelihood of the data, $p(y|\theta)$.
+
+## Parameter transformations
+HVI prescribes $q(\theta)$ to be the distribution of a transformed random variable,
+$\theta = T^{-1}(\zeta)$, where $\zeta = T(\theta)$ has a multivariate normal distribution
+(MVN) in unconstrained $\mathbb{R}^n$. The transformation, $T$, provides more flexibility
+to model the posterior and takes care of the case where the support of $q(\theta)$ is
+smaller than $\mathbb{R}^n$, the support of the MVN. For example, if the log of
+$\theta$ is normally distributed, then $\theta$ has a LogNormal distribution, and
+$\theta = T^{-1}(\zeta) \equiv e^{\zeta}$. The transformed joint density then is
+
+$$p_\zeta(y,\zeta) = p(y, T^{-1}(\zeta)) \, \left| \det J_{T^{-1}}(\zeta)\right|,$$
+
+where $\left| \det J_{T^{-1}}(\zeta)\right|$ denotes the absolute value of the determinant of the Jacobian of the inverse of the transformation, $T$, evaluated at $\zeta$.
+
+With those assumptions, the ELBO becomes
+
+$$\mathcal{L}(\phi) = \mathbb{E}_{q(\zeta)} \left[ \log p(y, T^{-1}(\zeta)) + \log \left| \det J_{T^{-1}}(\zeta)\right| \right] + \mathbb{H}_{q(\zeta)},$$
+
+where $\mathbb{H}_{q(\zeta)}$ is the entropy of the approximating density and the expectation is across a normally distributed random variable, $\zeta$.
+
+## Covariance structure
+
+HVI assumes that the transform of the latent variables follows a multivariate normal distribution: $\zeta = T((\theta_P, \theta_M)) = (\zeta_P, \zeta_M) \sim \mathcal{N}(\mu(\phi_g), \Sigma)$. The covariance matrix can be decomposed into standard deviations and the correlation matrix,
+
+$$\Sigma = diag(\sigma_\zeta) \, C_\zeta \, diag(\sigma_\zeta),$$
+
+where $\sigma_\zeta$ is the vector of standard deviations and $C_\zeta$ is the correlation matrix. HVI further assumes that uncertainties of site parameters, $\zeta_{M1}, \zeta_{M2}, \ldots$, differ only by their standard deviation, i.e.
+that the parameter correlations are the same for all sites. With the additional assumption of $\zeta_{Ms}$ being independent of $\zeta_P$, the covariance matrix has a block-diagonal structure with one block for $\zeta_P$ and $n_{site}$ repetitions of a block for $\zeta_{M}$. By definition of a correlation matrix, all the main diagonal elements are 1. E.g. for 2 elements in $\zeta_{P}$ and 3 elements in $\zeta_{M}$ this results in:
+
+$$\begin{pmatrix}
+\begin{matrix} 1 & \rho_{Pab} \\ \rho_{Pab} & 1 \end{matrix}
+& 0 & 0 & \cdots\\
+0 &
+\begin{matrix} 1 & \rho_{Mab} & \rho_{Mac} \\ \rho_{Mab} & 1 & \rho_{Mbc} \\ \rho_{Mac} & \rho_{Mbc} & 1 \end{matrix}
+& 0
+\\
+0 & 0 &
+\begin{matrix} 1 & \rho_{Mab} & \rho_{Mac} \\ \rho_{Mab} & 1 & \rho_{Mbc} \\ \rho_{Mac} & \rho_{Mbc} & 1 \end{matrix}
+\\
+\cdots & & & \ddots
+\end{pmatrix}$$
+
+In order to draw random numbers from such a normal distribution, the Cholesky
+decomposition of the covariance matrix is required: $\Sigma = U_{\Sigma}^T U_{\Sigma} =
+diag(\sigma_\zeta)^T U_C^T U_C \, diag(\sigma_\zeta)$, where $U_{\Sigma}$ and $U_C$ are
+the Cholesky factors of the covariance and correlation matrices, respectively. They are
+upper triangular matrices.
+
+Since the block-diagonal structure of the correlation matrix carries over to the Cholesky
+factor, $U_C$ is a block-diagonal matrix of smaller Cholesky factors. If HVI modeled the
+dependence between $\zeta_{Ms}$ and $\zeta_P$, the correlation matrix would have an
+additional block repeated in the first row and its transpose repeated in the first column
+of $\Sigma$, leading to a Cholesky factor $U_C$ having entries in all rows.
+
+Instead, HVI makes it possible to account for correlations between global and site
+parameters by providing the values of the global parameters to the machine learning
+model, $g$, in addition to the covariates:
+
+$$
+p(\zeta_{Ms}, \zeta_P) = p(\zeta_{Ms} | \zeta_P) \, p(\zeta_P)$$
+
+Since differentiation through a general Cholesky decomposition is problematic,
+HVI directly parameterizes the Cholesky factor of the correlation matrix rather than the
+correlation matrix itself. For details, see Wutzler (in prep.).
+
+## Combining variational inference (VI) with hybrid models
+
+Traditional VI estimates all means and uncertainties of the parameters
+$(\zeta_P, \zeta_{M1}, \ldots, \zeta_{Mn} )$ by inverting the model given the observations
+and their uncertainty. HVI directly inverts only the means of $\zeta_P$
+and predicts the means, $\mu_{\zeta_{Ms}}$, of the covariate-dependent parameters
+by the machine learning model $g(X_M, \zeta_P; \phi_g)$.
+If there is enough information in the observations, the ML model could predict additional
+parameters of the posterior distribution based on covariates, such as diagonals of the
+covariance matrix.
+
+Currently, HVI assumes
+$\zeta_{Ms} \sim \mathcal{N}(\mu_{\zeta_{Ms}}, \Sigma(\mu_{\zeta_{Ms}}))$,
+with the covariance matrix $\Sigma$ depending only on the
+magnitude of $\mu_{\zeta_{Ms}}$, i.e. being conditionally independent of the covariates, $X_M$
+(see the details on $\mu_{\zeta_{Ms}}$ below).
+
+In this specific setting, the parameter vector to be optimized,
+$\phi = (\phi_P, \phi_g, \phi_u)$, comprises
+- $\phi_P = \mu_{\zeta_P}$: the means of the distributions of the transformed global
+  parameters,
+- $\phi_g$: the parameters of the machine learning model, and
+- $\phi_u$: the parameterization of $\Sigma_\zeta$ beyond the means (see the sketch below).
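+
+As an illustration, the optimized parameters can be collected in a single flat
+vector with named components. The following is a minimal sketch using
+`ComponentArrays`; the component names `θP`, `ϕg`, and `unc` follow the usage in
+the package's scripts, while the sub-components of `unc` and all sizes shown are
+illustrative only.
+
+``` julia
+using ComponentArrays
+
+ϕ = ComponentVector(
+    θP = [0.0],                # ϕ_P: means μ_ζP of the global parameters
+    ϕg = zeros(100),           # ϕ_g: weights of the machine learning model
+    unc = ComponentVector(     # ϕ_u: additional parameters of the posterior
+        logσ2_P = [0.0],       # log variance of ζ_P
+        logσ2_M0 = [0.0, 0.0], # intercepts of the site-parameter log variances
+        logσ2_Mη = [0.0, 0.0], # slopes: scaling of variance with magnitude
+        a_P = [0.0],           # parameterization of the ζ_P correlation block
+        a_M = [0.0],           # parameterization of the ζ_M correlation block
+    ))
+ϕ.unc.a_M   # components are addressed by name rather than by position
+```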
+
+### Details
+Specifically, $\phi_u = (\log\sigma^2_P, \log\sigma^2_{M0}, \log\sigma^2_{M\eta}, a_P, a_M)$,
+where the variance of $\zeta_P$ is $\sigma^2_P$, the variance of the $i^{th}$ entry of
+$\zeta_{M}$ scales with its magnitude,
+$\log \sigma^2_{Mi} = \log\sigma^2_{M0i} + \log\sigma^2_{M\eta i} \, \mu_{\zeta_{Mi}}$,
+and $a_P$ and $a_M$ are parameter vectors of the blocks of the correlation matrix.
+
+In order to account for correlations between global and site-specific parameters,
+HVI models $p(\zeta)$ as a multivariate normal distribution that is a shifted
+zero-centered multivariate normal, $p(\zeta_r)$:
+
+$$\zeta = (\zeta_P, \zeta_{Ms}) = \zeta_r + (\mu_{\zeta_P}, \mu_{\zeta_{Ms}})
+\\
+\zeta_r = (\zeta_{rP}, \zeta_{rMs}) \sim \mathcal{N}(0, diag(\sigma_\zeta)^T C_\zeta \, diag(\sigma_\zeta))
+\\
+\mu_{\zeta_{Ms}} = g_s(X_M, \zeta_P; \phi_g)
+\\
+diag(\sigma_\zeta) = e^{(\log\sigma^2_P, \log\sigma^2_{M})/2}
+\\
+C_\zeta = U^T U
+\\
+U = \operatorname{BlockDiag}(a_P, a_M)$$
+
+where the predicted value of $\mu_{\zeta_{Ms}}$ may depend on the randomly drawn value of $\zeta_P = \zeta_{rP} + \mu_{\zeta_P}$. By this construction, HVI better supports the assumption that $\zeta_{rM}$ is conditionally independent of $\zeta_{rP}$, which is required to make the Cholesky factor, $U$, of the covariance matrix block-diagonal.
+
+The above procedure makes an additional subtle approximation. HVI allows the variance of $\zeta_{M}$ to scale with its magnitude. In the computation of the correlation matrix, however, HVI uses the mean, $\mu_{\zeta_{Mi}}$, rather than the actual sampled value, $\zeta_{Mi}$. If it used the actual value, then the distribution of $\zeta$ would need to be described as a general distribution, $p(\zeta) = p(\zeta_{Ms}|\zeta_P) \, p(\zeta_P)$, which would not be normal any more, and HVI could not compute the expectation by drawing centered normal random numbers.
+
+#### Implementation of the cost function
+In practical terms, the cost function
+- generates normally distributed random values $(\zeta_{rP}, \zeta_{rMs})$ based on the Cholesky factor of the covariance matrix, which depends on the optimized parameters $(a_P, a_M)$,
+- generates a sample of $\zeta_P$ by adding the optimized parameters $\mu_{\zeta_P}$ to $\zeta_{rP}$,
+- computes the expected value $\mu_{\zeta_{Ms}}$ using the machine learning model, given the covariates, $X_M$, the sampled $\zeta_P$, and the optimized parameters $\phi_g$,
+- generates a sample of $\zeta_{Ms}$ by adding the computed $\mu_{\zeta_{Ms}}$ to $\zeta_{rMs}$,
+- transforms $(\zeta_{P}, \zeta_{Ms})$ to the original scale to get a sample of model parameters $(\theta_{P}, \theta_{Ms})$,
+- computes the negative log-density of the observations for each sample using the physical model, $f$, and subtracts the log of the absolute determinant of the Jacobian of the transformation, evaluated at the sample,
+- approximates the expected value of the former by taking the mean across the samples, and
+- subtracts the entropy of the normal approximating distribution.
+
+Automatic differentiation through this cost function, including the calls to $g$, $T$, and $f$, makes it possible to estimate the parameters, $\phi$, by a stochastic gradient descent method.
+
diff --git a/docs/src/index.md b/docs/src/index.md
index f4a9cce..5d09653 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -6,9 +6,5 @@ CurrentModule = HybridVariationalInference
 
 Documentation for [HybridVariationalInference](https://github.com/EarthyScience/HybridVariationalInference.jl).
-```@index
-```
-
-```@autodocs
-Modules = [HybridVariationalInference]
-```
+
diff --git a/docs/src/problem.md b/docs/src/problem.md
new file mode 100644
index 0000000..1a2b295
--- /dev/null
+++ b/docs/src/problem.md
@@ -0,0 +1,20 @@
+# Problem
+
+Consider the case of parameter learning, a special case of hybrid models,
+where a machine learning model, $g_{\phi_g}$, uses known covariates $x_{Mi}$ at site $i$
+to predict a subset of the parameters, $\theta$, of the process-based model, $f$.
+
+We are interested in both
+- the uncertainty of hybrid model predictions, $ŷ$ (predictive posterior), and
+- the uncertainty of process-model parameters $\theta$, including their correlations
+  (posterior).
+
+For example, we have a soil organic matter process-model that predicts carbon stocks for
+different sites. We need to parameterize the unknown carbon use efficiency (CUE) of the soil
+microbial community, which differs by site but is hypothesized to correlate with climate variables
+and pedogenic factors, such as clay content.
+We apply a machine learning model to estimate CUE and fit it end-to-end with other
+parameters of the process-model to observed carbon stocks.
+In addition to the predicted CUE, we are interested in the uncertainty of CUE and its correlation with other parameters, such as the capacity of the soil minerals to bind carbon.
+That is, we are interested in the entire posterior probability distribution of the model parameters.
+
diff --git a/docs/src/reference/reference_internal.md b/docs/src/reference/reference_internal.md
new file mode 100644
index 0000000..53b2527
--- /dev/null
+++ b/docs/src/reference/reference_internal.md
@@ -0,0 +1,19 @@
+
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+# Reference of internal functions
+
+In this reference, you will find a detailed overview of internal functions.
+They are documented here mostly to aid development of the package.
+They are not part of the public API and may change without notice.
+
+``` @autodocs
+Modules = [
+    HybridVariationalInference,
+]
+Public = false
+```
+
diff --git a/docs/src/reference/reference_public.md b/docs/src/reference/reference_public.md
new file mode 100644
index 0000000..891df45
--- /dev/null
+++ b/docs/src/reference/reference_public.md
@@ -0,0 +1,18 @@
+
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+# Reference
+
+In this reference, you will find a detailed overview of the package API,
+i.e. the docstrings.
+
+``` @autodocs
+Modules = [
+    HybridVariationalInference, HybridVariationalInference.DoubleMM
+]
+Private = false
+```
+
diff --git a/docs/src/tutorials/Project.toml b/docs/src/tutorials/Project.toml
new file mode 100644
index 0000000..e828ad1
--- /dev/null
+++ b/docs/src/tutorials/Project.toml
@@ -0,0 +1,16 @@
+[deps]
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
+CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
+HybridVariationalInference = "a108c475-a4e2-4021-9a84-cfa7df242f64"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
+PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/docs/src/tutorials/_pbm_matrix.qmd b/docs/src/tutorials/_pbm_matrix.qmd
new file mode 100644
index 0000000..8d02eae
--- /dev/null
+++ b/docs/src/tutorials/_pbm_matrix.qmd
@@ -0,0 +1,20 @@
+```{julia}
+function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
+    # extract several covariates from xP
+    ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
+    S1 = (CA.getdata(xPc[:S1,:])::ST)
+    S2 = (CA.getdata(xPc[:S2,:])::ST)
+    #
+    # extract the parameters as row-repeated vectors
+    n_obs = size(S1, 1)
+    VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
+    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
+        p1 = CA.getdata(θc[:, par])::VT
+        repeat(p1', n_obs) # matrix: same for each concentration row in S1
+    end
+    #
+    # each variable is a matrix (n_obs x n_site)
+    r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+end
+```
+
diff --git a/docs/src/tutorials/basic_cpu.md b/docs/src/tutorials/basic_cpu.md
new file mode 100644
index 0000000..182c27e
--- /dev/null
+++ b/docs/src/tutorials/basic_cpu.md
@@ -0,0 +1,359 @@
+# Basic workflow without GPU
+
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+First, load the necessary packages.
+
+``` julia
+using HybridVariationalInference
+using HybridVariationalInference: HybridVariationalInference as HVI
+using ComponentArrays: ComponentArrays as CA
+using Bijectors
+using StableRNGs
+using SimpleChains
+using StatsFuns
+using MLUtils
+using DistributionFits
+```
+
+Next, we specify the many moving parts of hybrid variational inference (HVI).
+
+## The process-based model
+
+The example process-based model (PBM) predicts a double-Monod constrained rate
+for different substrate concentrations, `S1` and `S2`:
+
+$$
+y = r_0 + r_1 \frac{S_1}{K_1 + S_1} \frac{S_2}{K_2 + S_2}$$
+
+``` julia
+function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
+    # extract parameters not depending on order, i.e. whether they are in θP or θM
+    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
+        CA.getdata(θc[par])::ET
+    end
+    r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
+end
+```
+
+Its formulation is independent of which parameters are global, site-specific,
+or fixed during the model inversion.
+However, it cannot assume an ordering of the parameters, but needs to
+access the components by their symbolic names in the provided `ComponentArray`.
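+
+To make the calling convention concrete, here is a minimal check, assuming the
+definitions above; the parameter values mirror `θP`, `θM`, and `θFix` introduced
+later in this tutorial, and the two driver entries mirror the first entries of
+the synthetic drivers.
+
+``` julia
+# hedged example: evaluate the PBM once for a single site
+θc = CA.ComponentVector{Float32}(r0 = 0.3, r1 = 0.5, K1 = 0.2, K2 = 2.0)
+x = (S1 = Float32[0.5, 0.4], S2 = Float32[1.0, 3.0])  # two observations
+f_doubleMM(θc, x)  # 2-element Vector{Float32} of predicted rates
+```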
+
+## Likelihood function
+
+HVI requires the evaluation of the likelihood of the predictions.
+It corresponds to the cost of the predictions given some observations.
+
+The user specifies a function of the negative log-likelihood,
+`neg_logden(obs, pred, uncertainty_parameters)`,
+where all of the parameters are arrays with columns for sites.
+
+Here, we use the [`neg_logden_indep_normal`](@ref) function,
+which assumes observations to be independently normally
+distributed around the true value.
+The provided `y_unc` uncertainty parameters then correspond to
+`logσ2`, the log of the variance parameter of the normal distribution.
+
+``` julia
+py = neg_logden_indep_normal
+```
+
+## Global/site split, transformations, and priors
+
+### Global and site-specific parameters
+
+In this example, we will assign a fixed value to the r0 parameter, treat
+the K2 parameter as unknown but the same across sites, and predict
+r1 and K1 for each site separately, based on covariates known at the sites.
+
+Here, we provide initial values for them using `ComponentVector`s.
+
+``` julia
+FT = Float32
+θM0 = θM = CA.ComponentVector{FT}(r1=0.5, K1=0.2) # site-specific
+θP0 = θP = CA.ComponentVector{FT}(K2=2.0) # global: the same across sites
+θFix = CA.ComponentVector{FT}(r0=0.3) # fixed, i.e. not estimated
+```
+
+### Parameter Transformations
+
+HVI estimates parameters in an unconstrained space, where the probability
+density is nowhere strictly zero, and transforms them to the original
+constrained space.
+
+Here, our model parameters are strictly positive, and we use the exponential function
+to transform unconstrained estimates to the original constrained domain.
+
+``` julia
+transP = Stacked(HVI.Exp())
+transM = Stacked(HVI.Exp(), HVI.Exp())
+```
+
+Parameter transformations are specified using the `Bijectors` package.
+Because `Bijectors.elementwise(exp)` has problems with automatic differentiation (AD)
+on GPU, we use the public but non-exported [`Exp`](@ref) wrapper inside `Bijectors.Stacked`.
+
+### Prior information on parameters at constrained scale
+
+HVI is an approximate Bayesian analysis and combines prior information on
+the parameters with the model and the observed data.
+
+Here, we provide a wide prior by fitting a LogNormal distribution to each parameter, with
+- the mean corresponding to the initial value provided above, and
+- the 0.95-quantile at 3 times the mean,
+using the `DistributionFits.jl` package.
+
+``` julia
+θall = vcat(θP, θM)
+priors_dict = Dict{Symbol, Distribution}(
+    keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
+```
+
+## Observations, model drivers and covariates
+
+The model parameters are inverted using information on
+- the observed data, `y_o`,
+- its uncertainty, `y_unc`,
+- known covariates across sites, `xM`, and
+- model drivers, `xP`.
+
+Here, we use synthetic data generated by the package.
+
+``` julia
+rng = StableRNG(111)
+(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(
+    rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
+```
+
+Let's look at them.
+
+``` julia
+size(xM), size(xP), size(y_o), size(y_unc)
+```
+
+    ((5, 800), (16, 800), (8, 800), (8, 800))
+
+All of them have 800 columns, corresponding to 800 sites.
+There are 5 site covariates, 16 values of model drivers, and 8 observations per site.
+
+``` julia
+xP[:,1]
+```
+
+    ComponentVector{Float32}(S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.2, 0.1], S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0])
+
+Each column of the model drivers is a ComponentVector with
+components S1 and S2, corresponding to the concentrations for which outputs
+were observed.
+This allows the notation `x.S1` in the PBM above.
+
+The `y_unc` gets its meaning from the likelihood function specified with
+the problem below.
+
+### Providing data in batches
+
+HVI uses `MLUtils.DataLoader` to provide batches of the data during each
+iteration of the solver. In addition to the data, the tuple contains an
+index of the sites.
+
+``` julia
+n_site = size(y_o,2)
+n_batch = 20
+train_dataloader = MLUtils.DataLoader(
+    (xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
+```
+
+## The Machine-Learning model
+
+The machine-learning (ML) part predicts parameters of the posterior of the site-specific
+PBM parameters, given the covariates.
+Here, we specify a 3-layer feed-forward neural network using the `SimpleChains`
+framework, which works efficiently on CPU.
+
+``` julia
+n_out = length(θM) # number of site parameters to predict
+n_input = n_covar = size(xM,1)
+
+g_chain = SimpleChain(
+    static(n_input), # input dimension (optional)
+    TurboDense{true}(tanh, n_input * 4),
+    TurboDense{true}(tanh, n_input * 4),
+    # dense layer without bias that maps the n_out outputs to (0..1)
+    TurboDense{false}(logistic, n_out)
+)
+# get a template of the parameter vector, ϕg0
+g_chain_app, ϕg0 = construct_ChainsApplicator(rng, g_chain)
+```
+
+The `g_chain_app` `ChainsApplicator` predicts the parameters of the posterior
+approximation, given a vector of ML weights, `ϕg`.
+During construction, an initial template of this vector is created.
+This abstraction layer allows using different ML frameworks, e.g. replacing the
+`SimpleChains` model by `Flux` or `Lux`.
+
+### Using priors to scale ML-parameter estimates
+
+In order to balance gradients, the `g_chain_app` ModelApplicator defined above
+predicts on the scale (0..1).
+Now the priors are used to translate this to the parameter range by using the
+cumulative distribution function.
+
+Priors were specified at constrained scale, but the ML model predicts
+parameters on unconstrained scale.
+This transformation of the distribution can be mathematically worked out for
+specific prior distribution forms.
+However, for simplicity, a [`NormalScalingModelApplicator`](@ref)
+is fitted to the transformed 5% and 95% quantiles of the original prior.
+
+``` julia
+priorsM = [priors_dict[k] for k in keys(θM)]
+lowers, uppers = get_quantile_transformed(priorsM, transM)
+g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
+```
+
+The `g_chain_scaled` `ModelApplicator` now predicts on the unconstrained scale:
+it maps logistic predictions around 0.5 to the range of
+high prior probability of the parameters,
+and maps ML predictions near 0 or 1 towards the outer low-probability ranges.
+
+## Assembling the information
+
+All the specifications above are stored in a [`HybridProblem`](@ref) structure.
+
+First, a [`PBMSiteApplicator`](@ref) is constructed that translates a call
+with a vector of global parameters and a matrix of site parameters into
+calls of the process-based model (PBM) defined at the beginning.
+
+``` julia
+f_batch = f_allsites = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
+
+prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0,
+    f_batch, f_allsites, priors_dict, py,
+    transM, transP, train_dataloader, n_covar, n_site, n_batch)
+```
+
+## Perform the inversion
+
+Eventually, having assembled all the moving parts of the HVI, we can perform
+the inversion.
+
+``` julia
+using OptimizationOptimisers
+import Zygote
+
+solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
+
+(; probo, interpreters) = solve(prob, solver; rng,
+    callback = callback_loss(100), # output during fitting
+    epochs = 2,
+);
+```
+
+The solver object is constructed given the specific stochastic optimization algorithm
+and the number of Monte-Carlo samples that are drawn in each iteration
+from the predicted parameter posterior.
+
+Then the solver is applied to the problem using [`solve`](@ref)
+for a given number of iterations or epochs.
+This tutorial stays entirely on the CPU and
+hence does not require a GPU or GPU libraries.
+
+Among the return values are
+- `probo`: a copy of the HybridProblem with updated optimized parameters
+- `interpreters`: a `NamedTuple` with several `ComponentArrayInterpreter`s that
+  help in analyzing the results.
+
+## Using a population-level process-based model
+
+So far, the process-based model ran for each single site.
+For this simple model, some performance gains result from matrix computations
+when running the model for all sites within one batch simultaneously.
+
+In the following, the PBM specification accepts matrices as arguments
+for parameters and drivers
+and returns a matrix of predictions.
+For the parameters, one row corresponds to
+one site. For the drivers and predictions, one column corresponds to one site.
+
+``` julia
+function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
+    # extract the drivers S1 and S2 from xPc
+    ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
+    S1 = (CA.getdata(xPc[:S1,:])::ST)
+    S2 = (CA.getdata(xPc[:S2,:])::ST)
+    #
+    # extract the parameters as row-repeated vectors
+    n_obs = size(S1, 1)
+    VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
+    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
+        p1 = CA.getdata(θc[:, par])::VT
+        repeat(p1', n_obs) # matrix: same for each concentration row in S1
+    end
+    #
+    # each variable is a matrix (n_obs x n_site)
+    r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+end
+```
+
+Again, the function should not rely on the order of parameters but use symbolic indexing
+to extract the parameter vectors. For type stability of this symbolic indexing,
+it uses a workaround to get the type of a single row.
+Similarly, it uses type hints when indexing into the drivers, `xPc`, to extract
+sub-matrices by symbols. Alternatively, it could here rely on the structure and
+ordering of the columns in `xPc`.
+
+A corresponding [`PBMPopulationApplicator`](@ref) transforms calls with
+partitioned global and site parameters to calls of this matrix version of the PBM.
+The HVI problem needs to be updated with this new applicator.
+
+``` julia
+f_batch = PBMPopulationApplicator(f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
+f_allsites = PBMPopulationApplicator(f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
+probo_sites = HybridProblem(probo; f_batch, f_allsites)
+```
+
+For numerical efficiency, the number of sites within one batch is part of the
+`PBMPopulationApplicator`. Hence, we have two different functions, one applied
+to a batch of sites, and another applied to all sites.
+
+As a test of the new applicator, the results are refined by running a few more
+epochs of the optimization.
+
+``` julia
+(; probo) = solve(probo_sites, solver; rng,
+    callback = callback_loss(100), # output during fitting
+    epochs = 10,
+    #is_inferred = Val(true), # activate type-checks
+);
+```
+
+## Saving the results
+
+Extracting useful information from the optimized HybridProblem is covered
+in the following [Inspect results of fitted problem](@ref) tutorial.
+In order to use the results from this tutorial in other tutorials,
+the updated `probo` `HybridProblem` and the interpreters are saved to a JLD2 file.
+
+Before saving, the problem is updated to use the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
+of the PBM in module `DoubleMM` rather than
+module `Main`, to allow for easier reloading with JLD2.
+
+``` julia
+f_batch = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
+f_allsites = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
+probo2 = HybridProblem(probo; f_batch, f_allsites)
+```
+
+``` julia
+using JLD2
+fname = "intermediate/basic_cpu_results.jld2"
+mkpath("intermediate")
+if probo2 isa AbstractHybridProblem # do not save on failure above
+    jldsave(fname, false, IOStream; probo=probo2, interpreters)
+end
+```
diff --git a/docs/src/tutorials/basic_cpu.qmd b/docs/src/tutorials/basic_cpu.qmd
new file mode 100644
index 0000000..af96b35
--- /dev/null
+++ b/docs/src/tutorials/basic_cpu.qmd
@@ -0,0 +1,382 @@
+---
+title: "Basic workflow without GPU"
+engine: julia
+execute:
+  echo: true
+  output: false
+  daemon: 3600
+format:
+  commonmark:
+    variant: -raw_html
+    wrap: preserve
+bibliography: twutz_txt.bib
+---
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+First load necessary packages.
+```{julia}
+using HybridVariationalInference
+using HybridVariationalInference: HybridVariationalInference as HVI
+using ComponentArrays: ComponentArrays as CA
+using Bijectors
+using StableRNGs
+using SimpleChains
+using StatsFuns
+using MLUtils
+using DistributionFits
+```
+
+Next, we specify the many moving parts of Hybrid Variational Inference (HVI).
+
+## The process-based model
+The example process-based model (PBM) predicts a double-Monod constrained rate
+for different substrate concentrations, `S1` and `S2`.
+
+$$
+y = r_0 + r_1 \frac{S_1}{K_1 + S_1} \frac{S_2}{K_2 + S_2}
+$$
+
+```{julia}
+function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
+    # extract parameters not depending on order, i.e. whether they are in θP or θM
+    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
+        CA.getdata(θc[par])::ET
+    end
+    r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
+end
+```
+
+Its formulation is independent of which parameters are global, site-specific,
+or fixed during the model inversion.
+However, it cannot assume an ordering of the parameters, but needs to
+access the components by their symbolic names in the provided `ComponentArray`.
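+
+As a quick illustration (an ad-hoc sketch with arbitrary values, not part of
+the fitting below), the result is unchanged when the components are supplied
+in a different order:
+
+```{julia}
+#| output: true
+θa = CA.ComponentVector{Float32}(r0=0.3, r1=0.5, K1=0.2, K2=2.0)
+θb = CA.ComponentVector{Float32}(K2=2.0, K1=0.2, r1=0.5, r0=0.3) # reordered
+x1 = (S1=Float32[0.5, 0.4], S2=Float32[1.0, 5.0])
+f_doubleMM(θa, x1) ≈ f_doubleMM(θb, x1) # true
+```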
+
+## Likelihood function
+
+HVI requires the evaluation of the likelihood of the predictions.
+It corresponds to the cost of the predictions given some observations.
+
+The user specifies a function of the negative log-likelihood,
+`neg_logden(obs, pred, uncertainty_parameters)`,
+where all of the arguments are arrays with columns for sites.
+
+Here, we use the [`neg_logden_indep_normal`](@ref) function,
+which assumes observations to be independently normally distributed
+around a true value.
+The provided `y_unc` uncertainty parameters here correspond to
+`logσ2`, the log of the variance parameter of the normal distribution.
+Up to additive constants, this negative log-density is
+
+$$
+-\log p(y \mid \hat{y}) = \frac{1}{2} \sum_i \left( \log \sigma_i^2
++ \frac{(y_i - \hat{y}_i)^2}{\sigma_i^2} \right)
+$$
+
+```{julia}
+py = neg_logden_indep_normal
+```
+
+## Global and site parameters, transformations, and priors
+### Global and site-specific parameters
+In this example, we will assign a fixed value to the `r0` parameter, treat
+the `K2` parameter as unknown but the same across sites, and predict
+`r1` and `K1` for each site separately, based on covariates known at the sites.
+
+Here, we provide initial values for them using `ComponentVector`s.
+
+```{julia}
+FT = Float32
+θM0 = θM = CA.ComponentVector{FT}(r1=0.5, K1=0.2) # site level: differs across sites
+θP0 = θP = CA.ComponentVector{FT}(K2=2.0) # population level: same across sites
+θFix = CA.ComponentVector{FT}(r0=0.3) # fixed, i.e. not estimated
+```
+### Parameter Transformations
+HVI allows for transformations of parameters from an unconstrained space,
+where the probability density is nowhere exactly zero, to the original
+constrained space.
+
+Here, our model parameters are strictly positive, and we use the exponential function
+to transform unconstrained estimates to the original constrained domain.
+
+```{julia}
+transP = Stacked(HVI.Exp())
+transM = Stacked(HVI.Exp(), HVI.Exp())
+```
+
+Parameter transformations are specified using the `Bijectors` package.
+Because `Bijectors.elementwise(exp)` has problems with automatic differentiation (AD)
+on GPU, we use the public but non-exported [`Exp`](@ref) wrapper inside `Bijectors.Stacked`.
+
+### Prior information on parameters at constrained scale
+
+HVI is an approximate Bayesian analysis and combines prior information on
+the parameters with the model and observed data.
+
+Here, we provide a wide prior by fitting a lognormal distribution to each parameter,
+matching
+- the mean to the initial value provided above, and
+- the 0.95-quantile to 3 times the mean,
+using the `DistributionFits.jl` package.
+
+```{julia}
+θall = vcat(θP, θM)
+priors_dict = Dict{Symbol, Distribution}(
+    keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
+```
+
+## Observations, model drivers and covariates
+
+The model parameters are inverted using information on
+- the observed data, `y_o`,
+- its uncertainty, `y_unc`,
+- known covariates across sites, `xM`, and
+- model drivers, `xP`.
+
+Here, we use synthetic data generated by the package.
+
+```{julia}
+rng = StableRNG(111)
+(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(
+    rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
+```
+
+```{julia}
+#| echo: false
+#| eval: false
+() -> begin
+    (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc) =
+        gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
+end
+```
+
+Let's look at them.
+```{julia}
+#| output: true
+size(xM), size(xP), size(y_o), size(y_unc)
+```
+All of them have 800 columns, corresponding to 800 sites.
+There are 5 site covariates, 16 values of model drivers, and 8 observations per site.
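+
+For example, each column of `xM` holds the covariate vector of one site
+(a quick peek; the covariates are synthetic here):
+
+```{julia}
+#| output: true
+xM[:,1] # the 5 covariates of the first site
+```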
+
+```{julia}
+#| output: true
+xP[:,1]
+```
+Each column of the model drivers is a ComponentVector with
+components S1 and S2, corresponding to the concentrations for which outputs
+were observed.
+This allows the notation `x.S1` in the PBM above.
+
+The `y_unc` gets its meaning from the likelihood function specified with
+the problem below.
+
+### Providing data in batches
+
+HVI uses `MLUtils.DataLoader` to provide batches of the data during each
+iteration of the solver. In addition to the data, the tuple contains an
+index of the sites.
+
+```{julia}
+n_site = size(y_o,2)
+n_batch = 20
+train_dataloader = MLUtils.DataLoader(
+    (xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
+```
+
+## The Machine-Learning model
+
+The machine-learning (ML) part predicts parameters of the posterior of the site-specific
+PBM parameters, given the covariates.
+Here, we specify a 3-layer feed-forward neural network using the `SimpleChains`
+framework, which works efficiently on CPU.
+
+```{julia}
+n_out = length(θM) # number of site parameters to predict
+n_input = n_covar = size(xM,1)
+
+g_chain = SimpleChain(
+    static(n_input), # input dimension (optional)
+    TurboDense{true}(tanh, n_input * 4),
+    TurboDense{true}(tanh, n_input * 4),
+    # dense layer without bias that maps the n_out outputs to (0..1)
+    TurboDense{false}(logistic, n_out)
+)
+# get a template of the parameter vector, ϕg0
+g_chain_app, ϕg0 = construct_ChainsApplicator(rng, g_chain)
+```
+
+The `g_chain_app` `ChainsApplicator` predicts the parameters of the posterior
+approximation, given a vector of ML weights, `ϕg`.
+During construction, an initial template of this vector is created.
+This abstraction layer allows using different ML frameworks, e.g. replacing the
+`SimpleChains` model by `Flux` or `Lux`.
+
+### Using priors to scale ML-parameter estimates
+
+In order to balance gradients, the `g_chain_app` ModelApplicator defined above
+predicts on the scale (0..1).
+Now the priors are used to translate this to the parameter range by using the
+cumulative distribution function.
+
+Priors were specified at constrained scale, but the ML model predicts
+parameters on unconstrained scale.
+This transformation of the distribution can be mathematically worked out for
+specific prior distribution forms.
+However, for simplicity, a [`NormalScalingModelApplicator`](@ref)
+is fitted to the transformed 5% and 95% quantiles of the original prior.
+
+```{julia}
+priorsM = [priors_dict[k] for k in keys(θM)]
+lowers, uppers = get_quantile_transformed(priorsM, transM)
+g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
+```
+
+The `g_chain_scaled` `ModelApplicator` now predicts on the unconstrained scale:
+it maps logistic predictions around 0.5 to the range of
+high prior probability of the parameters,
+and maps ML predictions near 0 or 1 towards the outer low-probability ranges.
+
+
+## Assembling the information
+
+All the specifications above are stored in a [`HybridProblem`](@ref) structure.
+
+First, a [`PBMSiteApplicator`](@ref) is constructed that translates a call
+with a vector of global parameters and a matrix of site parameters into
+calls of the process-based model (PBM) defined at the beginning.
+
+```{julia}
+f_batch = f_allsites = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
+
+prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0,
+    f_batch, f_allsites, priors_dict, py,
+    transM, transP, train_dataloader, n_covar, n_site, n_batch)
+```
+
+```{julia}
+#| eval: false
+#| echo: false
+
+# test invoking
+#θMs = stack(Iterators.repeated(θM, n_batch); dims=1)
+θMs = θM' .+ (randn(n_batch, size(θM,1)) .* 0.05)
+x_batch = xP[:,1:n_batch]
+y1 = f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
+
+
+() -> begin
+    y1 - y_o[:,1:n_batch] # check size and roughly equal
+    #using Test
+    #@inferred f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
+    @inferred Vector{Float64} f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
+    #using Cthulhu
+    #@descend_code_warntype f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))
+    prob0 = HVI.DoubleMM.DoubleMMCase()
+    f_batch0 = get_hybridproblem_PBmodel(prob0; use_all_sites = false)
+    y1f = f_batch0(θP, θMs, x_batch)[2]
+    y1 .- y1f # equal
+end
+```
+
+## Perform the inversion
+
+Eventually, having assembled all the moving parts of the HVI, we can perform
+the inversion.
+
+```{julia}
+using OptimizationOptimisers
+import Zygote
+
+solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
+
+(; probo, interpreters) = solve(prob, solver; rng,
+    callback = callback_loss(100), # output during fitting
+    epochs = 2,
+);
+```
+
+The solver object is constructed given the specific stochastic optimization algorithm
+and the number of Monte-Carlo samples that are drawn in each iteration
+from the predicted parameter posterior.
+
+Then the solver is applied to the problem using [`solve`](@ref)
+for a given number of iterations or epochs.
+This tutorial stays entirely on the CPU and
+hence does not require a GPU or GPU libraries.
+
+Among the return values are
+- `probo`: a copy of the HybridProblem with updated optimized parameters
+- `interpreters`: a `NamedTuple` with several `ComponentArrayInterpreter`s that
+  help in analyzing the results.
+
+## Using a population-level process-based model
+
+So far, the process-based model ran for each single site.
+For this simple model, some performance gains result from matrix computations
+when running the model for all sites within one batch simultaneously.
+
+In the following, the PBM specification accepts matrices as arguments
+for parameters and drivers
+and returns a matrix of predictions.
+For the parameters, one row corresponds to
+one site. For the drivers and predictions, one column corresponds to one site.
+
+
+{{< include _pbm_matrix.qmd >}}
+
+Again, the function should not rely on the order of parameters but use symbolic indexing
+to extract the parameter vectors. For type stability of this symbolic indexing,
+it uses a workaround to get the type of a single row.
+Similarly, it uses type hints when indexing into the drivers, `xPc`, to extract
+sub-matrices by symbols. Alternatively, it could here rely on the structure and
+ordering of the columns in `xPc`.
+
+A corresponding [`PBMPopulationApplicator`](@ref) transforms calls with
+partitioned global and site parameters to calls of this matrix version of the PBM.
+The HVI problem needs to be updated with this new applicator.
+
+```{julia}
+f_batch = PBMPopulationApplicator(f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
+f_allsites = PBMPopulationApplicator(f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
+probo_sites = HybridProblem(probo; f_batch, f_allsites)
+```
+
+For numerical efficiency, the number of sites within one batch is part of the
+`PBMPopulationApplicator`. Hence, we have two different functions, one applied
+to a batch of sites, and another applied to all sites.
+
+As a test of the new applicator, the results are refined by running a few more
+epochs of the optimization.
+
+```{julia}
+(; probo) = solve(probo_sites, solver; rng,
+    callback = callback_loss(100), # output during fitting
+    epochs = 10,
+    #is_inferred = Val(true), # activate type-checks
+);
+```
+
+## Saving the results
+Extracting useful information from the optimized HybridProblem is covered
+in the following [Inspect results of fitted problem](@ref) tutorial.
+In order to use the results from this tutorial in other tutorials,
+the updated `probo` `HybridProblem` and the interpreters are saved to a JLD2 file.
+
+Before saving, the problem is updated to use the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
+of the PBM in module `DoubleMM` rather than
+module `Main`, to allow for easier reloading with JLD2.
+
+```{julia}
+f_batch = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
+f_allsites = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
+probo2 = HybridProblem(probo; f_batch, f_allsites)
+```
+
+```{julia}
+using JLD2
+fname = "intermediate/basic_cpu_results.jld2"
+mkpath("intermediate")
+if probo2 isa AbstractHybridProblem # do not save on failure above
+    jldsave(fname, false, IOStream; probo=probo2, interpreters)
+end
+```
+
+```{julia}
+#| eval: false
+#| echo: false
+probo = load(fname, "probo"; iotype = IOStream);
+```
\ No newline at end of file
diff --git a/docs/src/tutorials/how_to_guides/blocks_corr.qmd b/docs/src/tutorials/how_to_guides/blocks_corr.qmd
new file mode 100644
index 0000000..c98a8d1
--- /dev/null
+++ b/docs/src/tutorials/how_to_guides/blocks_corr.qmd
@@ -0,0 +1,440 @@
+---
+title: "How to model blocks of independent parameters in the correlation matrix"
+engine: julia
+execute:
+  echo: true
+  output: false
+  daemon: 3600
+format:
+  commonmark:
+    variant: -raw_html
+    wrap: preserve
+bibliography: twutz_txt.bib
+---
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+Modelling full correlations among PBM parameters requires many degrees of
+freedom.
+
+To decrease the number of parameters to estimate, HVI allows decomposing the
+correlations among site parameters, or the correlations among global parameters,
+into sub-blocks of independent parameters.
+
+TODO
+
+First load necessary packages.
+```{julia}
+using HybridVariationalInference
+using HybridVariationalInference: HybridVariationalInference as HVI
+using ComponentArrays: ComponentArrays as CA
+using Bijectors
+using StableRNGs
+using SimpleChains
+using StatsFuns
+using MLUtils
+using DistributionFits
+```
+
+Next, we specify the many moving parts of Hybrid Variational Inference (HVI).
+
+## The process-based model
+The example process-based model (PBM) predicts a double-Monod constrained rate
+for different substrate concentrations, `S1` and `S2`.
+
+$$
+y = r_0 + r_1 \frac{S_1}{K_1 + S_1} \frac{S_2}{K_2 + S_2}
+$$
+
+```{julia}
+function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
+    # extract parameters not depending on order, i.e. whether they are in θP or θM
+    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
+        CA.getdata(θc[par])::ET
+    end
+    r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
+end
+```
+
+Its formulation is independent of which parameters are global, site-specific,
+or fixed during the model inversion.
+However, it cannot assume an ordering of the parameters, but needs to
+access the components by their symbolic names in the provided `ComponentArray`.
+
+## Likelihood function
+
+HVI requires the evaluation of the likelihood of the predictions.
+It corresponds to the cost of the predictions given some observations.
+
+The user specifies a function of the negative log-likelihood,
+`neg_logden(obs, pred, uncertainty_parameters)`,
+where all of the arguments are arrays with columns for sites.
+
+Here, we use the [`neg_logden_indep_normal`](@ref) function,
+which assumes observations to be independently normally distributed
+around a true value.
+The provided `y_unc` uncertainty parameters here correspond to
+`logσ2`, the log of the variance parameter of the normal distribution.
+
+```{julia}
+py = neg_logden_indep_normal
+```
+
+## Templates, transformations, and correlation structure of parameters
+### Global and site-specific parameters
+In this example, we will assign a fixed value to the `r0` parameter, treat
+the `K2` parameter as unknown but the same across sites, and predict
+`r1` and `K1` for each site separately, based on covariates known at the sites.
+
+Here, we provide initial values for them using `ComponentVector`s.
+
+```{julia}
+FT = Float32
+θM0 = θM = CA.ComponentVector{FT}(r1=0.5, K1=0.2) # site level: differs across sites
+θP0 = θP = CA.ComponentVector{FT}(K2=2.0) # population level: same across sites
+θFix = CA.ComponentVector{FT}(r0=0.3) # fixed, i.e. not estimated
+```
+### Parameter Transformations
+HVI allows for transformations of parameters from an unconstrained space,
+where the probability density is nowhere exactly zero, to the original
+constrained space.
+
+Here, our model parameters are strictly positive, and we use the exponential function
+to transform unconstrained estimates to the original constrained scale.
+
+```{julia}
+transP = Stacked(HVI.Exp())
+transM = Stacked(HVI.Exp(), HVI.Exp())
+```
+
+Parameter transformations are specified using the `Bijectors` package.
+Because `Bijectors.elementwise(exp)` has problems with automatic differentiation (AD)
+on GPU, we use the non-exported [`Exp`](@ref) wrapper inside `Bijectors.Stacked`.
+
+### Prior information on parameters at constrained scale
+
+HVI is an approximate Bayesian analysis and combines prior information on
+the parameters with the model and observed data.
+
+Here, we provide a wide prior by fitting a lognormal distribution to each parameter,
+matching
+- the mean to the initial value provided above, and
+- the 0.95-quantile to 3 times the mean.
+
+```{julia}
+θall = vcat(θP, θM)
+priors_dict = Dict{Symbol, Distribution}(
+    keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
+```
+
+
+### Correlation structure
+HVI models the posterior of the parameters at unconstrained scale using a
+multivariate normal distribution. It estimates a parameterization of the
+associated blocks in the correlation matrix and requires a specification
+of the block structure.
+
+This is done by specifying the positions of the ends of the blocks for
+the global (P) and the site-specific parameters (M), respectively, using
+a `NamedTuple` of integer vectors.
+For example, with four site parameters where the first two and the last two
+form independent blocks, one would specify `M = [2, 4]`.
+
+```{julia}
+cor_ends = (P=[length(θP)], M=[length(θM)])
+```
+
+Here, we specify a single entry each, meaning there is only one big
+block for each group, spanning all of its parameters.
+
+### Further parameters for the posterior approximation
+HVI uses additional fitted parameters to represent the means and the
+covariance matrix of the posterior distribution of model parameters.
+Here, we construct initial estimates using [`init_hybrid_ϕunc`](@ref).
+
+```{julia}
+#| output: true
+ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT))
+```
+
+The returned `ComponentVector` has entries for
+- `logσ2_ζP`: log of the estimated marginal variance of global parameters
+- `coef_logσ2_ζMs`: coefficients of a linear model for the log of
+  the estimated marginal variance of site-dependent parameters, depending on
+  the predicted parameter value
+- `ρsP`, `ρsM`: parameterization of the blocks in the correlation matrix for
+  global and site-specific parameters, respectively.
+
+## Observations, model drivers and covariates
+
+The model is inverted using
+- the observed data, `y_o`,
+- its uncertainty, `y_unc`,
+- known covariates across sites, `xM`, and
+- model drivers, `xP`.
+
+Here, we use synthetic data generated by the package.
+
+```{julia}
+rng = StableRNG(111)
+scenario = Val((:omit_r0, :covarK2, ))
+(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario)
+```
+
+```{julia}
+#| echo: false
+#| eval: false
+() -> begin
+    (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc) =
+        gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario)
+end
+```
+
+Let's look at them.
+```{julia}
+#| output: true
+size(xM), size(xP), size(y_o), size(y_unc)
+```
+All of them have 800 columns, corresponding to 800 sites.
+There are 5 site covariates, 16 values of model drivers, and 8 observations per site.
+
+```{julia}
+#| output: true
+xP[:,1]
+```
+Each column of the model drivers is a ComponentVector with
+components S1 and S2, corresponding to the concentrations for which outputs
+were observed.
+
+The `y_unc` gets its meaning from the likelihood function specified with
+the problem below.
+
+### Providing data in batches
+
+HVI uses `MLUtils.DataLoader` to provide batches of the data during each
+iteration of the solver. In addition to the data, the tuple contains an
+index of the sites.
+
+```{julia}
+n_site = size(y_o,2)
+n_batch = 20
+train_dataloader = MLUtils.DataLoader(
+    (xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
+```
+
+## The Machine-Learning model
+
+The machine-learning (ML) part predicts parameters of the posterior of the site-specific
+PBM parameters, given the covariates.
+Here, we specify a 3-layer feed-forward neural network using the `SimpleChains`
+framework, which works efficiently on CPU.
+
+```{julia}
+n_out = length(θM) # number of site parameters to predict
+pbm_covars = (:K2,) # global parameters used as input to the ML model
+n_covar = size(xM,1)
+n_input = n_covar + length(pbm_covars)
+
+g_chain = SimpleChain(
+    static(n_input), # input dimension (optional)
+    TurboDense{true}(tanh, n_input * 4),
+    TurboDense{true}(tanh, n_input * 4),
+    # dense layer without bias that maps the n_out outputs to (0..1)
+    TurboDense{false}(logistic, n_out)
+)
+# get a template of the parameter vector, ϕg0
+g_chain_app, ϕg0 = construct_ChainsApplicator(rng, g_chain)
+```
+
+The `g_chain_app` `ChainsApplicator` predicts the parameters of the posterior,
+given a vector of ML weights, `ϕg`.
+During construction, an initial template of this vector is created.
+This abstraction layer allows using different ML frameworks, e.g. replacing the
+`SimpleChains` model by `Flux` or `Lux`.
+
+### Conditional independence of global and site parameters
+The `pbm_covars` entry specifies which sampled global parameters should be
+provided as inputs to the ML model.
+
+The reason for providing global parameters as inputs is to allow for
+correlations among site parameters and global parameters, despite
+the required conditional independence of the parameters.
+
+### Using priors to scale ML-parameter estimates
+
+In order to balance gradients, the `g_chain_app` ModelApplicator defined above
+predicts on the scale (0..1).
+Now the priors are used to translate this to the parameter range by using the
+cumulative distribution function.
+
+However, the priors were specified at constrained scale, while the ML model
+predicts parameters on the unconstrained scale, where the priors are needed as well.
+This transformation of the distribution can be mathematically worked out for
+specific prior distribution forms.
+For simplicity, however, a [`NormalScalingModelApplicator`](@ref)
+is fitted here to the transformed 5% and 95% quantiles of the original prior.
+
+```{julia}
+priorsM = [priors_dict[k] for k in keys(θM)]
+lowers, uppers = get_quantile_transformed(priorsM, transM)
+g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
+```
+
+The `g_chain_scaled` `ModelApplicator` now predicts on the unconstrained scale:
+it maps ML predictions around 0.5 to the range of
+high prior probability of the parameters,
+and maps ML predictions near 0 or 1 towards the outer low-probability ranges.
+
+
+## Assembling the information
+
+All the specifications above are stored in a [`HybridProblem`](@ref) structure.
+
+First, a [`PBMSiteApplicator`](@ref) is constructed that translates a call
+with a vector of global parameters and a matrix of site parameters into
+calls of the process-based model (PBM) defined at the beginning.
+
+```{julia}
+f_batch = f_allsites = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
+
+prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0, ϕunc0,
+    f_batch, f_allsites, priors_dict, py,
+    transM, transP, train_dataloader, n_covar, n_site, n_batch,
+    cor_ends, pbm_covars)
+```
+
+```{julia}
+#| eval: false
+#| echo: false
+
+# test invoking
+#θMs = stack(Iterators.repeated(θM, n_batch); dims=1)
+θMs = θM' .+ (randn(n_batch, size(θM,1)) .* 0.05)
+x_batch = xP[:,1:n_batch]
+y1 = f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
+
+
+() -> begin
+    y1 - y_o[:,1:n_batch] # check size and roughly equal
+    #using Test
+    #@inferred f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
+    @inferred Vector{Float64} f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
+    #using Cthulhu
+    #@descend_code_warntype f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))
+    prob0 = HVI.DoubleMM.DoubleMMCase()
+    f_batch0 = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = false)
+    y1f = f_batch0(θP, θMs, x_batch)[2]
+    y1 .- y1f # equal
+end
+```
+
+## Perform the inversion
+
+Eventually, having assembled all the moving parts of the HVI, we can perform
+the inversion.
+
+```{julia}
+using OptimizationOptimisers
+import Zygote
+
+solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
+
+(; probo, interpreters) = solve(prob, solver; scenario, rng,
+    callback = callback_loss(100), # output during fitting
+    epochs = 2,
+    gdev = identity, # do not use GPU, here
+);
+```
+
+The solver object is constructed given the specific stochastic optimization algorithm
+and the number of Monte-Carlo samples that are drawn in each iteration
+from the predicted parameter posterior.
+
+Then the solver is applied to the problem using [`solve`](@ref)
+for a given number of iterations or epochs.
+For this tutorial, we additionally specify that the function to transfer structures to
+the GPU is the identity function, so that everything stays on the CPU; this tutorial
+hence does not require a GPU or GPU libraries.
+
+Among the return values are
+- `probo`: a copy of the HybridProblem with updated optimized parameters
+- `interpreters`: a `NamedTuple` with several `ComponentArrayInterpreter`s that
+  help in analyzing the results.
+
+## Using a population-level process-based model
+
+So far, we have specified the process-based model to run for a single site.
+For this simple model, we can gain some performance from matrix computations
+by running the model for all sites simultaneously.
+
+We now specify the PBM to accept matrices as arguments for parameters and drivers
+and to return a matrix of predictions. For the parameters, one row corresponds to
+one site. For the drivers and predictions, one column corresponds to one site.
+
+
+```{julia}
+function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
+    # extract the drivers S1 and S2 from xPc
+    ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
+    S1 = (CA.getdata(xPc[:S1,:])::ST)
+    S2 = (CA.getdata(xPc[:S2,:])::ST)
+    #
+    # extract the parameters as row-repeated vectors
+    n_obs = size(S1, 1)
+    VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
+    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
+        p1 = CA.getdata(θc[:, par])::VT
+        repeat(p1', n_obs) # matrix: same for each concentration row in S1
+    end
+    #
+    # each variable is a matrix (n_obs x n_site)
+    r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+end
+```
+
+Again, the function should not rely on the order of parameters, but use symbolic indexing
+to extract the parameter vectors. For type stability of this symbolic indexing,
+it uses a workaround to get the type of a single row.
+Similarly, it uses type hints when indexing into the drivers, `xPc`, to extract
+sub-matrices by symbols. Alternatively, it could here rely on the structure and
+ordering of the columns in `xPc`.
+
+We use the corresponding [`PBMPopulationApplicator`](@ref)
+and update the HVI problem.
+
+```{julia}
+f_batch = PBMPopulationApplicator(f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
+f_allsites = PBMPopulationApplicator(f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
+probo_sites = HybridProblem(probo; f_batch, f_allsites)
+```
+
+For numerical efficiency, the number of sites within one batch is part of the
+`PBMPopulationApplicator`. Hence, we have two different functions, one applied
+to a batch of sites, and another applied to all sites.
+
+As a test of the new applicator, we refine the results by running a few more
+epochs of the optimization.
+
+```{julia}
+(; probo) = solve(probo_sites, solver; scenario, rng,
+    callback = callback_loss(100), # output during fitting
+    epochs = 10,
+    gdev = identity, # do not use GPU, here
+    #is_inferred = Val(true), # activate type-checks
+);
+```
+
+## Saving the results
+Extracting useful information from the optimized HybridProblem is covered
+in the [Inspect results of fitted problem](@ref) tutorial.
+
+To use it there, we save the `probo` HybridProblem and the interpreters to a JLD2 file.
+
+```{julia}
+using JLD2
+fname = "intermediate/basic_cpu_results.jld2"
+mkpath("intermediate")
+jldsave(fname, false, IOStream; probo, interpreters)
+```
+
+```{julia}
+#| eval: false
+#| echo: false
+probo = load(fname, "probo"; iotype = IOStream);
+```
\ No newline at end of file
diff --git a/docs/src/tutorials/how_to_guides/corr_site_global.qmd b/docs/src/tutorials/how_to_guides/corr_site_global.qmd
new file mode 100644
index 0000000..fc5854c
--- /dev/null
+++ b/docs/src/tutorials/how_to_guides/corr_site_global.qmd
@@ -0,0 +1,431 @@
+---
+title: "How to account for correlations between site and global parameters"
+engine: julia
+execute:
+  echo: true
+  output: false
+  daemon: 3600
+format:
+  commonmark:
+    variant: -raw_html
+    wrap: preserve
+bibliography: twutz_txt.bib
+---
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+First load necessary packages.
+```{julia}
+using HybridVariationalInference
+using HybridVariationalInference: HybridVariationalInference as HVI
+using ComponentArrays: ComponentArrays as CA
+using Bijectors
+using StableRNGs
+using SimpleChains
+using StatsFuns
+using MLUtils
+using DistributionFits
+```
+
+Next, we specify the many moving parts of Hybrid Variational Inference (HVI).
+
+## The process-based model
+The example process-based model (PBM) predicts a double-Monod constrained rate
+for different substrate concentrations, `S1` and `S2`.
+
+$$
+y = r_0 + r_1 \frac{S_1}{K_1 + S_1} \frac{S_2}{K_2 + S_2}
+$$
+
+```{julia}
+function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
+    # extract parameters not depending on order, i.e. whether they are in θP or θM
+    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
+        CA.getdata(θc[par])::ET
+    end
+    r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
+end
+```
+
+Its formulation is independent of which parameters are global, site-specific,
+or fixed during the model inversion.
+However, it cannot assume an ordering of the parameters, but needs to
+access the components by their symbolic names in the provided `ComponentArray`.
+
+## Likelihood function
+
+HVI requires the evaluation of the likelihood of the predictions.
+It corresponds to the cost of the predictions given some observations.
+
+The user specifies a function of the negative log-likelihood,
+`neg_logden(obs, pred, uncertainty_parameters)`,
+where all of the arguments are arrays with columns for sites.
+
+Here, we use the [`neg_logden_indep_normal`](@ref) function,
+which assumes observations to be independently normally distributed
+around a true value.
+The provided `y_unc` uncertainty parameters here correspond to
+`logσ2`, the log of the variance parameter of the normal distribution.
+
+```{julia}
+py = neg_logden_indep_normal
+```
+
+## Templates, transformations, and correlation structure of parameters
+### Global and site-specific parameters
+In this example, we will assign a fixed value to the `r0` parameter, treat
+the `K2` parameter as unknown but the same across sites, and predict
+`r1` and `K1` for each site separately, based on covariates known at the sites.
+
+Here, we provide initial values for them using `ComponentVector`s.
+
+```{julia}
+FT = Float32
+θM0 = θM = CA.ComponentVector{FT}(r1=0.5, K1=0.2) # site level: differs across sites
+θP0 = θP = CA.ComponentVector{FT}(K2=2.0) # population level: same across sites
+θFix = CA.ComponentVector{FT}(r0=0.3) # fixed, i.e. not estimated
+```
+### Parameter Transformations
+HVI allows for transformations of parameters from an unconstrained space,
+where the probability density is nowhere exactly zero, to the original
+constrained space.
+
+Here, our model parameters are strictly positive, and we use the exponential function
+to transform unconstrained estimates to the original constrained scale.
+
+```{julia}
+transP = Stacked(HVI.Exp())
+transM = Stacked(HVI.Exp(), HVI.Exp())
+```
+
+Parameter transformations are specified using the `Bijectors` package.
+Because `Bijectors.elementwise(exp)` has problems with automatic differentiation (AD)
+on GPU, we use the non-exported [`Exp`](@ref) wrapper inside `Bijectors.Stacked`.
+
+### Prior information on parameters at constrained scale
+
+HVI is an approximate Bayesian analysis and combines prior information on
+the parameters with the model and observed data.
+
+Here, we provide a wide prior by fitting a lognormal distribution to each parameter,
+matching
+- the mean to the initial value provided above, and
+- the 0.95-quantile to 3 times the mean.
+
+```{julia}
+θall = vcat(θP, θM)
+priors_dict = Dict{Symbol, Distribution}(
+    keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
+```
+
+
+### Correlation structure
+HVI models the posterior of the parameters at unconstrained scale using a
+multivariate normal distribution. It estimates a parameterization of the
+associated blocks in the correlation matrix and requires a specification
+of the block structure.
+
+This is done by specifying the positions of the ends of the blocks for
+the global (P) and the site-specific parameters (M), respectively, using
+a `NamedTuple` of integer vectors.
+
+```{julia}
+cor_ends = (P=[length(θP)], M=[length(θM)])
+```
+
+Here, we specify a single entry each, meaning there is only one big
+block for each group, spanning all of its parameters.
+
+### Further parameters for the posterior approximation
+HVI uses additional fitted parameters to represent the means and the
+covariance matrix of the posterior distribution of model parameters.
+Here, we construct initial estimates using [`init_hybrid_ϕunc`](@ref).
+
+```{julia}
+#| output: true
+ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT))
+```
+
+The returned `ComponentVector` has entries for
+- `logσ2_ζP`: log of the estimated marginal variance of global parameters
+- `coef_logσ2_ζMs`: coefficients of a linear model for the log of
+  the estimated marginal variance of site-dependent parameters, depending on
+  the predicted parameter value
+- `ρsP`, `ρsM`: parameterization of the blocks in the correlation matrix for
+  global and site-specific parameters, respectively.
+
+## Observations, model drivers and covariates
+
+The model is inverted using
+- the observed data, `y_o`,
+- its uncertainty, `y_unc`,
+- known covariates across sites, `xM`, and
+- model drivers, `xP`.
+
+Here, we use synthetic data generated by the package.
+
+```{julia}
+rng = StableRNG(111)
+scenario = Val((:omit_r0, :covarK2, ))
+(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario)
+```
+
+```{julia}
+#| echo: false
+#| eval: false
+() -> begin
+    (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc) =
+        gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario)
+end
+```
+
+Let's look at them.
+```{julia}
+#| output: true
+size(xM), size(xP), size(y_o), size(y_unc)
+```
+All of them have 800 columns, corresponding to 800 sites.
+There are 5 site covariates, 16 values of model drivers, and 8 observations per site.
+
+```{julia}
+#| output: true
+xP[:,1]
+```
+Each column of the model drivers is a ComponentVector with
+components S1 and S2, corresponding to the concentrations for which outputs
+were observed.
+
+The `y_unc` gets its meaning from the likelihood function specified with
+the problem below.
+
+### Providing data in batches
+
+HVI uses `MLUtils.DataLoader` to provide batches of the data during each
+iteration of the solver. In addition to the data, the tuple contains an
+index of the sites.
+
+```{julia}
+n_site = size(y_o,2)
+n_batch = 20
+train_dataloader = MLUtils.DataLoader(
+    (xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
+```
+
+## The Machine-Learning model
+
+The machine-learning (ML) part predicts parameters of the posterior of the site-specific
+PBM parameters, given the covariates.
+Here, we specify a 3-layer feed-forward neural network using the `SimpleChains`
+framework, which works efficiently on CPU.
+
+```{julia}
+n_out = length(θM) # number of site parameters to predict
+pbm_covars = (:K2,) # global parameters used as input to the ML model
+n_covar = size(xM,1)
+n_input = n_covar + length(pbm_covars)
+
+g_chain = SimpleChain(
+    static(n_input), # input dimension (optional)
+    TurboDense{true}(tanh, n_input * 4),
+    TurboDense{true}(tanh, n_input * 4),
+    # dense layer without bias that maps the n_out outputs to (0..1)
+    TurboDense{false}(logistic, n_out)
+)
+# get a template of the parameter vector, ϕg0
+g_chain_app, ϕg0 = construct_ChainsApplicator(rng, g_chain)
+```
+
+The `g_chain_app` `ChainsApplicator` predicts the parameters of the posterior,
+given a vector of ML weights, `ϕg`.
+During construction, an initial template of this vector is created.
+This abstraction layer allows using different ML frameworks, e.g. replacing the
+`SimpleChains` model by `Flux` or `Lux`.
+
+### Conditional independence of global and site parameters
+The `pbm_covars` entry specifies which sampled global parameters should be
+provided as inputs to the ML model.
+
+The reason for providing global parameters as inputs is to allow for
+correlations among site parameters and global parameters, despite
+the required conditional independence of the parameters.
+
+### Using priors to scale ML-parameter estimates
+
+In order to balance gradients, the `g_chain_app` ModelApplicator defined above
+predicts on the scale (0..1).
+Now the priors are used to translate this to the parameter range by using the
+cumulative distribution function.
+
+However, the priors were specified at constrained scale, while the ML model
+predicts parameters on the unconstrained scale, where the priors are needed as well.
+This transformation of the distribution can be mathematically worked out for
+specific prior distribution forms.
+For simplicity, however, a [`NormalScalingModelApplicator`](@ref)
+is fitted here to the transformed 5% and 95% quantiles of the original prior.
+
+```{julia}
+priorsM = [priors_dict[k] for k in keys(θM)]
+lowers, uppers = get_quantile_transformed(priorsM, transM)
+g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
+```
+
+The `g_chain_scaled` `ModelApplicator` now predicts on the unconstrained scale:
+it maps ML predictions around 0.5 to the range of
+high prior probability of the parameters,
+and maps ML predictions near 0 or 1 towards the outer low-probability ranges.
+
+
+## Assembling the information
+
+All the specifications above are stored in a [`HybridProblem`](@ref) structure.
+
+First, a [`PBMSiteApplicator`](@ref) is constructed that translates a call
+with a vector of global parameters and a matrix of site parameters into
+calls of the process-based model (PBM) defined at the beginning.
+
+```{julia}
+f_batch = f_allsites = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
+
+prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0, ϕunc0,
+    f_batch, f_allsites, priors_dict, py,
+    transM, transP, train_dataloader, n_covar, n_site, n_batch,
+    cor_ends, pbm_covars)
+```
+
+```{julia}
+#| eval: false
+#| echo: false
+
+# test invoking
+#θMs = stack(Iterators.repeated(θM, n_batch); dims=1)
+θMs = θM' .+ (randn(n_batch, size(θM,1)) .* 0.05)
+x_batch = xP[:,1:n_batch]
+y1 = f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
+
+
+() -> begin
+    y1 - y_o[:,1:n_batch] # check size and roughly equal
+    #using Test
+    #@inferred f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
+    @inferred Vector{Float64} f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))[2]
+    #using Cthulhu
+    #@descend_code_warntype f_batch(CA.getdata(θP), CA.getdata(θMs), CA.getdata(x_batch))
+    prob0 = HVI.DoubleMM.DoubleMMCase()
+    f_batch0 = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = false)
+    y1f = f_batch0(θP, θMs, x_batch)[2]
+    y1 .- y1f # equal
+end
+```
+
+## Perform the inversion
+
+Eventually, having assembled all the moving parts of the HVI, we can perform
+the inversion.
+
+```{julia}
+using OptimizationOptimisers
+import Zygote
+
+solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
+
+(; probo, interpreters) = solve(prob, solver; scenario, rng,
+    callback = callback_loss(100), # output during fitting
+    epochs = 2,
+    gdev = identity, # do not use GPU, here
+);
+```
+
+The solver object is constructed given the specific stochastic optimization algorithm
+and the number of Monte-Carlo samples that are drawn in each iteration
+from the predicted parameter posterior.
+
+Then the solver is applied to the problem using [`solve`](@ref)
+for a given number of iterations or epochs.
+For this tutorial, we additionally specify that the function to transfer structures to
+the GPU is the identity function, so that everything stays on the CPU; this tutorial
+hence does not require a GPU or GPU libraries.
+
+Among the return values are
+- `probo`: a copy of the HybridProblem with updated optimized parameters
+- `interpreters`: a `NamedTuple` with several `ComponentArrayInterpreter`s that
+  help in analyzing the results.
+
+## Using a population-level process-based model
+
+So far, we have specified the process-based model to run for a single site.
+For this simple model, we can gain some performance from matrix computations
+by running the model for all sites simultaneously.
+
+We now specify the PBM to accept matrices as arguments for parameters and drivers
+and to return a matrix of predictions. For the parameters, one row corresponds to
+one site. For the drivers and predictions, one column corresponds to one site.
+
+
+```{julia}
+function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
+    # extract the drivers S1 and S2 from xPc
+    ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
+    S1 = (CA.getdata(xPc[:S1,:])::ST)
+    S2 = (CA.getdata(xPc[:S2,:])::ST)
+    #
+    # extract the parameters as row-repeated vectors
+    n_obs = size(S1, 1)
+    VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
+    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
+        p1 = CA.getdata(θc[:, par])::VT
+        repeat(p1', n_obs) # matrix: same for each concentration row in S1
+    end
+    #
+    # each variable is a matrix (n_obs x n_site)
+    r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+end
+```
+
+Again, the function should not rely on the order of parameters, but use symbolic indexing
+to extract the parameter vectors. For type stability of this symbolic indexing,
+it uses a workaround to get the type of a single row.
+Similarly, it uses type hints when indexing into the drivers, `xPc`, to extract
+sub-matrices by symbols. Alternatively, it could here rely on the structure and
+ordering of the columns in `xPc`.
+
+We use the corresponding [`PBMPopulationApplicator`](@ref)
+and update the HVI problem.
+
+```{julia}
+f_batch = PBMPopulationApplicator(f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
+f_allsites = PBMPopulationApplicator(f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
+probo_sites = HybridProblem(probo; f_batch, f_allsites)
+```
+
+For numerical efficiency, the number of sites within one batch is part of the
+`PBMPopulationApplicator`. Hence, we have two different functions, one applied
+to a batch of sites, and another applied to all sites.
+
+As a test of the new applicator, we refine the results by running a few more
+epochs of the optimization.
+
+```{julia}
+(; probo) = solve(probo_sites, solver; scenario, rng,
+    callback = callback_loss(100), # output during fitting
+    epochs = 10,
+    gdev = identity, # do not use GPU, here
+    #is_inferred = Val(true), # activate type-checks
+);
+```
+
+## Saving the results
+Extracting useful information from the optimized HybridProblem is covered
+in the [Inspect results of fitted problem](@ref) tutorial.
+
+To use it there, we save the `probo` HybridProblem and the interpreters to a JLD2 file.
+
+```{julia}
+using JLD2
+fname = "intermediate/basic_cpu_results.jld2"
+mkpath("intermediate")
+jldsave(fname, false, IOStream; probo, interpreters)
+```
+
+```{julia}
+#| eval: false
+#| echo: false
+probo = load(fname, "probo"; iotype = IOStream);
+```
\ No newline at end of file
diff --git a/docs/src/tutorials/inspect_results.md b/docs/src/tutorials/inspect_results.md
new file mode 100644
index 0000000..d164660
--- /dev/null
+++ b/docs/src/tutorials/inspect_results.md
@@ -0,0 +1,144 @@
+# Inspect results of fitted problem
+
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+First load necessary packages.
+
+``` julia
+using HybridVariationalInference
+using StableRNGs
+using ComponentArrays: ComponentArrays as CA
+using SimpleChains # for reloading the optimized problem
+using DistributionFits
+using JLD2
+using CairoMakie
+using PairPlots # scatterplot matrices
+```
+
+This tutorial uses the fitted object saved in the
+[Basic workflow without GPU](@ref) tutorial.
+
+``` julia
+fname = "intermediate/basic_cpu_results.jld2"
+print(abspath(fname))
+probo, interpreters = load(fname, "probo", "interpreters");
+```
+
+## Sample the posterior
+
+A sample of the posterior (and, by running the PBM on it, of the
+predictive posterior) can be obtained using the function [`sample_posterior`](@ref).
+
+``` julia
+using StableRNGs
+rng = StableRNG(112)
+n_sample_pred = 400
+(; θsP, θsMs) = sample_posterior(rng, probo; n_sample_pred)
+```
+
+Let's look at the results.
+
+``` julia
+size(θsP), size(θsMs)
+```
+
+    ((1, 400), (800, 2, 400))
+
+The last dimension is the number of samples, the second-last dimension is
+the respective parameter. `θsMs` has an additional dimension denoting
+the site for which parameters are sampled.
+
+They are ComponentArrays whose parameter dimension names can be used for indexing:
+
+``` julia
+θsMs[1,:r1,:] # samples of r1 at the first site
+```
+
+## Corner plots
+
+The relations between different variables can be inspected using
+scatterplot matrices, also called corner plots or pair plots.
+`PairPlots.jl` provides a Makie implementation of those.
+
+Here, we plot the global parameters and the site parameters for the first site.
+
+``` julia
+i_site = 1
+θ1 = vcat(θsP, θsMs[i_site,:,:])
+θ1_nt = NamedTuple(k => CA.getdata(θ1[k,:]) for k in keys(θ1[:,1])) #
+plt = pairplot(θ1_nt)
+```
+
+![](inspect_results_files/figure-commonmark/cell-8-output-1.png)
+
+The plot shows that the parameters for the first site, *K*₁ and *r*₁, are correlated,
+but that we did not model correlation with the global parameter, *K*₂.
+
+Note that this plot shows only the first out of 800 sites.
+HVI estimated a 1602-dimensional posterior distribution including
+covariances among parameters.
+
+## Expected values and marginal variances
+
+Let's look at how the estimated uncertainty of a site parameter changes with
+its expected value.
+
+``` julia
+par = :K1
+θmean = [mean(θsMs[s,par,:]) for s in axes(θsMs, 1)]
+θsd = [std(θsMs[s,par,:]) for s in axes(θsMs, 1)]
+fig = Figure(); ax = Axis(fig[1,1], xlabel="mean($par)",ylabel="sd($par)")
+scatter!(ax, θmean, θsd)
+fig
+```
+
+![](inspect_results_files/figure-commonmark/cell-10-output-1.png)
+
+We see that *K*₁ across sites ranges from about 0.18 to 0.25, and that
+its estimated uncertainty is about 0.034, slightly decreasing with the
+value of the parameter.
+
+## Predictive Posterior
+
+In addition to the uncertainty in parameters, we are also interested in
+the uncertainty of predictions, i.e. the predictive posterior.
+
+We can either run the PBM for all the parameter samples that we obtained already,
+using [`apply_process_model`](@ref), or use [`predict_hvi`](@ref), which combines
+sampling the posterior and predictive posterior and returns the additional
+`NamedTuple` entry `y`.
+
+``` julia
+(; y, θsP, θsMs) = predict_hvi(rng, probo; n_sample_pred)
+```
+
+``` julia
+size(y)
+```
+
+    (8, 800, 400)
+
+Again, the last dimension is the sample.
+The other dimensions correspond to the observations we provided for the fitting:
+the first dimension is the observation within one site, the second dimension is the site.
+
+Let's look at how the uncertainty of the 4th observation scales with its
+predicted magnitude across sites.
+
+``` julia
+i_obs = 4
+ymean = [mean(y[i_obs,s,:]) for s in axes(θsMs, 1)]
+ysd = [std(y[i_obs,s,:]) for s in axes(θsMs, 1)]
+fig = Figure(); ax = Axis(fig[1,1], xlabel="mean(y$i_obs)",ylabel="sd(y$i_obs)")
+scatter!(ax, ymean, ysd)
+fig
+```
+
+![](inspect_results_files/figure-commonmark/cell-13-output-1.png)
+
+We see that the predicted values for the associated substrate concentrations range
+from about 0.51 to 0.59, with an estimated standard deviation around 0.005 that
+decreases with the predicted value.
diff --git a/docs/src/tutorials/inspect_results.qmd b/docs/src/tutorials/inspect_results.qmd
new file mode 100644
index 0000000..a60ae24
--- /dev/null
+++ b/docs/src/tutorials/inspect_results.qmd
@@ -0,0 +1,172 @@
+---
+title: "Inspect results of fitted problem"
+engine: julia
+execute:
+  echo: true
+  output: false
+  daemon: 3600
+format:
+  commonmark:
+    variant: -raw_html
+    wrap: preserve
+bibliography: twutz_txt.bib
+---
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+First load necessary packages.
+
+```{julia}
+using HybridVariationalInference
+using StableRNGs
+using ComponentArrays: ComponentArrays as CA
+using SimpleChains # for reloading the optimized problem
+using DistributionFits
+using JLD2
+using CairoMakie
+using PairPlots # scatterplot matrices
+```
+
+This tutorial uses the fitted object saved at the end of the
+[Basic workflow without GPU](@ref) tutorial.
+
+```{julia}
+fname = "intermediate/basic_cpu_results.jld2"
+print(abspath(fname))
+probo, interpreters = load(fname, "probo", "interpreters");
+```
+
+```{julia}
+#| eval: false
+#| echo: false
+# not necessary any more with DoubleMM.f_doubleMM_sites
+# {{< include _pbm_matrix.qmd >}}
+# outside the notebook, the ModelApplicator needs to be reset, because fθ was defined in the notebook module
+#θFix = CA.ComponentVector{eltype(probo.θP)}(r0=0.3)
+θFix = CA.ComponentVector{eltype(probo.θP)}(
+    r0=probo.f_allsites.θFixm[1])
+_xP_batch = first(probo.train_dataloader)[2]
+f_batch = PBMPopulationApplicator(
+    f_doubleMM_sites, probo.n_batch; probo.θP, probo.θM, θFix, xPvec=_xP_batch[:,1])
+f_allsites = PBMPopulationApplicator(
+    f_doubleMM_sites, probo.n_site; probo.θP, probo.θM, θFix, xPvec=_xP_batch[:,1])
+probo = HybridProblem(probo; f_batch, f_allsites)
+```
+
+## Sample the posterior
+
+A sample of both the posterior and the predictive posterior can be obtained
+using function [`sample_posterior`](@ref).
+
+```{julia}
+using StableRNGs
+rng = StableRNG(112)
+n_sample_pred = 400
+(; θsP, θsMs) = sample_posterior(rng, probo; n_sample_pred)
+```
+
+Let's look at the results.
+
+```{julia}
+#| output: true
+size(θsP), size(θsMs)
+```
+
+The last dimension is the number of samples, the second-last dimension is
+the respective parameter. `θsMs` has an additional dimension denoting
+the site for which parameters are sampled.
+
+They are ComponentArrays with named parameter axes that can be indexed by symbol:
+
+```{julia}
+θsMs[1,:r1,:] # sample of r1 of the first site
+```
+
+## Corner plots
+
+The relation between different variables can be inspected well with
+scatterplot matrices, also called corner plots or pair plots.
+`PairPlots.jl` provides a Makie implementation of those.
+
+Here, we plot the global parameters and the site parameters for the first site.
+```{julia}
+#| output: true
+i_site = 1
+θ1 = vcat(θsP, θsMs[i_site,:,:])
+θ1_nt = NamedTuple(k => CA.getdata(θ1[k,:]) for k in keys(θ1[:,1])) #
+plt = pairplot(θ1_nt)
+```
+
+The plot shows that the parameters for the first site, $K_1$ and $r_1$, are correlated,
+but that we did not model correlation with the global parameter, $K_2$.
+
+Note that this plot shows only the first out of 800 sites.
+HVI estimated a 1602-dimensional posterior distribution including
+covariances among parameters.
+
+```{julia}
+#| eval: false
+#| echo: false
+probo.θP, probo.θM
+```
+
+
+## Expected values and marginal variances
+
+Let's look at how the estimated uncertainty of a site parameter changes with
+its expected value.
+
+```{julia}
+#| output: true
+par = :K1
+θmean = [mean(θsMs[s,par,:]) for s in axes(θsMs, 1)]
+θsd = [std(θsMs[s,par,:]) for s in axes(θsMs, 1)]
+fig = Figure(); ax = Axis(fig[1,1], xlabel="mean($par)",ylabel="sd($par)")
+scatter!(ax, θmean, θsd)
+fig
+```
+
+We see that $K_1$ across sites ranges from about 0.18 to 0.25, and that
+its estimated uncertainty is about 0.034, slightly decreasing with the
+value of the parameter.
+
+## Predictive Posterior
+
+In addition to the uncertainty in parameters, we are also interested in
+the uncertainty of predictions, i.e. the predictive posterior.
+
+We can either run the PBM for all the parameter samples that we obtained already,
+using [`apply_process_model`](@ref), or use [`predict_hvi`](@ref), which combines
+sampling the posterior and the predictive posterior and returns the additional
+`NamedTuple` entry `y`.
+
+```{julia}
+(; y, θsP, θsMs) = predict_hvi(rng, probo; n_sample_pred)
+```
+
+```{julia}
+#| output: true
+size(y)
+```
+
+Again, the last dimension is the sample.
+The other dimensions correspond to the observations we provided for the fitting:
+the first dimension is the observation within one site, the second dimension is the site.
+
+Let's look at how the uncertainty of the 4th observation scales with its
+predicted magnitude across sites.
+
+```{julia}
+#| output: true
+i_obs = 4
+ymean = [mean(y[i_obs,s,:]) for s in axes(θsMs, 1)]
+ysd = [std(y[i_obs,s,:]) for s in axes(θsMs, 1)]
+fig = Figure(); ax = Axis(fig[1,1], xlabel="mean(y$i_obs)",ylabel="sd(y$i_obs)")
+scatter!(ax, ymean, ysd)
+fig
+```
+
+We see that predicted values for the associated substrate concentrations range from about
+0.51 to 0.59, with an estimated standard deviation around 0.005 that decreases
+with the predicted value.
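+
+Beyond mean and standard deviation, the predictive uncertainty can also be
+summarized by quantile bands. The following is a minimal sketch building on the
+sample `y` and the index `i_obs` obtained above; it only assumes the
+`Statistics` standard library in addition to the packages already loaded.
+
+```{julia}
+#| output: true
+using Statistics
+y4 = y[i_obs, :, :]                             # (n_site x n_sample) slice
+ymed = map(median, eachrow(y4))                 # per-site median prediction
+qlo = map(r -> quantile(r, 0.05), eachrow(y4))  # lower bound of the 90% band
+qhi = map(r -> quantile(r, 0.95), eachrow(y4))  # upper bound of the 90% band
+fig = Figure(); ax = Axis(fig[1,1], xlabel="median(y$i_obs)", ylabel="width of 90% band")
+scatter!(ax, ymed, qhi .- qlo)
+fig
+```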
+ + + + + diff --git a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-10-output-1.png b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-10-output-1.png new file mode 100644 index 0000000..32ef60b Binary files /dev/null and b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-10-output-1.png differ diff --git a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-11-output-1.png b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-11-output-1.png new file mode 100644 index 0000000..32ef60b Binary files /dev/null and b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-11-output-1.png differ diff --git a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-13-output-1.png b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-13-output-1.png new file mode 100644 index 0000000..f994f9e Binary files /dev/null and b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-13-output-1.png differ diff --git a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-14-output-1.png b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-14-output-1.png new file mode 100644 index 0000000..f994f9e Binary files /dev/null and b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-14-output-1.png differ diff --git a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-8-output-1.png b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-8-output-1.png new file mode 100644 index 0000000..a6c133b Binary files /dev/null and b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-8-output-1.png differ diff --git a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-9-output-1.png b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-9-output-1.png new file mode 100644 index 0000000..a6c133b Binary files /dev/null and b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-9-output-1.png differ diff --git a/docs/src/tutorials/intermediate/basic_cpu_results.jld2 b/docs/src/tutorials/intermediate/basic_cpu_results.jld2 new file mode 100644 index 0000000..126fa04 Binary files /dev/null and b/docs/src/tutorials/intermediate/basic_cpu_results.jld2 differ diff --git a/docs/src/tutorials/twutz_txt.bib b/docs/src/tutorials/twutz_txt.bib new file mode 100644 index 0000000..1da805e --- /dev/null +++ b/docs/src/tutorials/twutz_txt.bib @@ -0,0 +1,36 @@ +@Article{Wutzler20, + author = {Thomas Wutzler and Oscar Perez-Priego and Kendalynn Morris and Tarek S. 
El-Madany and Mirco Migliavacca},
+  title = {Soil {CO}$_2$ efflux errors are lognormally distributed -- implications and guidance},
+  journal = {Geoscientific Instrumentation, Methods and Data Systems},
+  year = {2020},
+  volume = {9},
+  number = {1},
+  pages = {239--254},
+  month = {may},
+  comment = {tomasch 19},
+  doi = {10.5194/gi-9-239-2020},
+  file = {:2007/stat/Wutzler20_lognormal_soilCO2efflux.pdf:PDF},
+  keywords = {stat, tomasch, soil, respiration, chamber},
+  owner = {twutz},
+  publisher = {Copernicus {GmbH}},
+  timestamp = {2020.06.02},
+}
+
+@Article{Wutzler22,
+  author = {Thomas Wutzler and Lin Yu and Marion Schrumpf and Sönke Zaehle},
+  journal = {Geoscientific Model Development},
+  title = {Simulating long-term responses of soil organic matter turnover to substrate stoichiometry by abstracting fast and small-scale microbial processes: the Soil Enzyme Steady Allocation Model ({SESAM}; v3.0)},
+  year = {2022},
+  month = {nov},
+  number = {22},
+  pages = {8377--8393},
+  volume = {15},
+  comment = {tomasch 24},
+  doi = {10.5194/gmd-15-8377-2022},
+  file = {:tomasch/Wutzler22_SESAM_upscaling.pdf:PDF},
+  keywords = {model, soil, enzyme, SESAM, quasi-steady-state-assumption},
+  owner = {twutz},
+  publisher = {Copernicus {GmbH}},
+  timestamp = {2022.11.18},
+}
+
diff --git a/docs/src_stash/test1.qmd b/docs/src_stash/test1.qmd
new file mode 100644
index 0000000..1afa6b2
--- /dev/null
+++ b/docs/src_stash/test1.qmd
@@ -0,0 +1,78 @@
+---
+title: "A julia engine notebook"
+engine: julia
+execute:
+  echo: true
+format:
+  commonmark:
+    variant: -raw_html
+    wrap: preserve
+  html:
+    code-fold: false
+bibliography: twutz_txt.bib
+---
+
+```{julia}
+2 + 2
+```
+
+```{julia}
+#| echo: false
+using Plots
+```
+
+A code block with `#| echo: false`.
+```{julia}
+#| echo: false
+using Pkg
+Pkg.status()
+```
+
+```{julia}
+run(`which julia`);
+```
+
+## Equation
+inline: $\sqrt{\pi}$ inside text.
+
+Display equation @eq-general with a number and a reference:
+
+$$
+\begin{aligned}
+\frac{n!}{k!(n - k)!} &= \binom{n}{k}
+\\
+x^2 &= 3
+\end{aligned}
+$$ {#eq-general}
+
+## Literature references
+
+See [@Wutzler20].
+
+## Plots
+
+single
+```{julia}
+#| label: fig-limits
+#| fig-cap: "Errorbar limit selector"
+plot(1:4)
+```
+
+
+## Not working with Documenter
+Julia code is not colored, and output is rendered as block code.
+
+Figures: the figure captions are not kept, and links do not work.
+
+References: the number is correctly inserted, but the link does not work.
+
+Documenter admonitions: the four spaces before the contents are removed, so they do not work.
+
+Quarto columns are rendered as several paragraphs.
+
+With gfm, references are formatted badly, and there are errors with unclosed divs
+after Documenter processes it.
+
+
+
+## References
diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl
index da8275d..1d8f462 100644
--- a/ext/HybridVariationalInferenceFluxExt.jl
+++ b/ext/HybridVariationalInferenceFluxExt.jl
@@ -81,12 +81,4 @@ function HVI.construct_3layer_MLApplicator(
    construct_ChainsApplicator(rng, g_chain, float_type)
end
-function HVI.cpu_ca(ca::CA.ComponentArray)
-    CA.ComponentArray(cpu(CA.getdata(ca)), CA.getaxes(ca))
-end
-
-
-
-
-
end # module
diff --git a/src/AbstractHybridProblem.jl b/src/AbstractHybridProblem.jl
index 1d4a898..798ed57 100644
--- a/src/AbstractHybridProblem.jl
+++ b/src/AbstractHybridProblem.jl
@@ -164,7 +164,7 @@ Return a DataLoader that provides a tuple of
- `xP`: Iterator of process-model drivers, with one element per site
- `y_o`: matrix of observations with added noise, with one column per site
- `y_unc`: matrix `sizeof(y_o)` of uncertainty information
-- `i_sites`: Vector of indices of sites in toal sitevector for the minibatch
+- `i_sites`: Vector of indices of sites in the minibatch
"""
function get_hybridproblem_train_dataloader end
@@ -191,25 +191,23 @@ end
"""
-    gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader,
-        scenario = (),
-        gdev = gpu_device(),
-        gdev_M = :use_gpu ∈ scenario ? gdev : identity,
-        gdev_P = :f_on_gpu ∈ scenario ? gdev : identity,
+    gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader; gdev_M, gdev_P,
        batchsize = dataloader.batchsize, partial = dataloader.partial
    )

Put relevant parts of the DataLoader to gpu, depending on scenario.
"""
-function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader;
-    scenario::Val{scen} = Val(()),
-    gdev = gpu_device(),
-    gdev_M = :use_gpu ∈ _val_value(scenario) ? gdev : identity,
-    gdev_P = :f_on_gpu ∈ _val_value(scenario) ? gdev : identity,
+function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader; gdevs,
+    gdev_M = gdevs.gdev_M,
+    gdev_P = gdevs.gdev_P,
+    # scenario::Val{scen} = Val(()),
+    # gdev = gpu_device(),
+    # gdev_M = :use_gpu ∈ _val_value(scenario) ? gdev : identity,
+    # gdev_P = :f_on_gpu ∈ _val_value(scenario) ? gdev : identity,
    batchsize = dataloader.batchsize, partial = dataloader.partial
-    ) where scen
+    )
    xM, xP, y_o, y_unc, i_sites = dataloader.data
    xM_dev = gdev_M(xM)
    xP_dev, y_o_dev, y_unc_dev = (gdev_P(xP), gdev_P(y_o), gdev_P(y_unc))
@@ -218,6 +216,28 @@ function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader;
    return(train_loader_dev)
end
+"""
+    get_gdev_MP(scenario::Val{scen}) where scen
+
+Configure the functions that put data and computations on the gpu device
+for the given `scenario`,
+checking for `:use_gpu` and `:f_on_gpu` in `scenario`.
+Returns a `NamedTuple` `(;gdev_M, gdev_P)`.
+"""
+function get_gdev_MP(scenario::Val{scen}) where scen
+    gdev_gpu = gpu_device()
+    gdev_M = :use_gpu ∈ _val_value(scenario) ? gdev_gpu : identity
+    gdev_P = :f_on_gpu ∈ _val_value(scenario) ? gdev_gpu : identity
+    (;gdev_M, gdev_P)
+end
+
+function infer_cdev(gdevs; gdev_M = gdevs.gdev_M, gdev_P = gdevs.gdev_P)
+    # if gdev_M is already on CPU use identity,
+    # if gdev_P is on GPU also use identity to not transfer to CPU
+    cdev=!(gdev_M isa MLDataDevices.AbstractGPUDevice) ? identity :
+        ((gdev_P isa MLDataDevices.AbstractGPUDevice) ?
identity : cpu_device()) +end + # function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenario = ()) # rng::AbstractRNG = Random.default_rng() # get_hybridproblem_train_dataloader(rng, prob; scenario) @@ -265,50 +285,5 @@ function setup_PBMpar_interpreter(θP, θM, θall = vcat(θP, θM)) intθ, θFix end -struct PBmodelClosure{θFixT, θFix_devT, AX, pos_xPT} - θFix::θFixT - θFix_dev::θFix_devT - intθ::StaticComponentArrayInterpreter{AX} - isP::Matrix{Int} - n_site_batch::Int - pos_xP::pos_xPT -end - -function PBmodelClosure(prob::AbstractHybridProblem; scenario::Val{scen}, - use_all_sites = false, - gdev = :f_on_gpu ∈ _val_value(scenario) ? gpu_device() : identity, - θall, int_xP1, -) where {scen} - n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) - n_site_batch = use_all_sites ? n_site : n_batch - #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers - par_templates = get_hybridproblem_par_templates(prob; scenario) - intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall) - θFix = repeat(θFix1', n_site_batch) - θFix_dev = gdev(θFix) - intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1)) - #int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1) - isP = repeat(axes(par_templates.θP, 1)', n_site_batch) - pos_xP = get_positions(int_xP1) - PBmodelClosure(;θFix, θFix_dev, intθ, isP, n_site_batch, pos_xP) -end - -function PBmodelClosure(; - θFix::θFixT, - θFix_dev::θFix_devT, - intθ::StaticComponentArrayInterpreter{AX}, - isP::Matrix{Int}, - n_site_batch::Int, - pos_xP::pos_xPT, -) where {θFixT, θFix_devT, AX, pos_xPT} - PBmodelClosure{θFixT, θFix_devT, AX, pos_xPT}( - θFix::AbstractArray, θFix_dev, intθ, isP, n_site_batch, pos_xP) -end - - - - - - diff --git a/src/ComponentArrayInterpreter.jl b/src/ComponentArrayInterpreter.jl index 7f15d6c..960d5ee 100644 --- a/src/ComponentArrayInterpreter.jl +++ b/src/ComponentArrayInterpreter.jl @@ -108,8 +108,10 @@ get_concrete(cai::ComponentArrayInterpreter) = StaticComponentArrayInterpreter{c """ ComponentArrayInterpreter(; kwargs...) ComponentArrayInterpreter(::AbstractComponentArray) + ComponentArrayInterpreter(::AbstractComponentArray, n_dims::NTuple{N,<:Integer}) ComponentArrayInterpreter(n_dims::NTuple{N,<:Integer}, ::AbstractComponentArray) + ComponentArrayInterpreter(n_dims::NTuple{N,<:Integer}, ::AbstractComponentArray, m_dims::NTuple{M,<:Integer}) Construct a `ComponentArrayInterpreter <: AbstractComponentArrayInterpreter` with components being vectors of given length or given model of a `AbstractComponentArray`. @@ -152,19 +154,30 @@ function ComponentArrayInterpreter(vc::CA.AbstractComponentArray) ComponentArrayInterpreter(CA.getaxes(vc)) end +const CAorCAI = Union{CA.AbstractComponentArray, AbstractComponentArrayInterpreter} + # Attach axes to matrices and arrays of ComponentArrays # with ComponentArrays in the first dimensions (e.g. rownames of a matrix or array) -function ComponentArrayInterpreter( - ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where {N} - ComponentArrayInterpreter(CA.getaxes(ca), n_dims) +function ComponentArrayInterpreter(ca::CAorCAI, n_dims::NTuple{N,<:Integer}) where {N} + ComponentArrayInterpreter((), CA.getaxes(ca), n_dims) +end +# with ComponentArrays in the last dimensions (e.g. columnnames of a matrix) +function ComponentArrayInterpreter(n_dims::NTuple{N,<:Integer}, ca::CAorCAI) where {N} + ComponentArrayInterpreter(n_dims, CA.getaxes(ca), ()) end +# with ComponentArrays in the center dimensions (e.g. 
columnnames of a 3D-array) function ComponentArrayInterpreter( - cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where {N} - ComponentArrayInterpreter(CA.getaxes(cai), n_dims) + n_dims::NTuple{N,<:Integer}, ca::CAorCAI, m_dims::NTuple{M,<:Integer}) where {N,M} + ComponentArrayInterpreter(n_dims, CA.getaxes(ca), m_dims) end + function ComponentArrayInterpreter( - axes::NTuple{M,<:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N} - axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...) + n_dims::NTuple{N,<:Integer}, axes::NTuple{A,<:CA.AbstractAxis}, + m_dims::NTuple{M,<:Integer}) where {N,A,M} + axes_ext = ( + map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., + axes..., + map(n_dim -> CA.Axis(i=1:n_dim), m_dims)...) ComponentArrayInterpreter(axes_ext) end @@ -182,25 +195,11 @@ function stack_ca_int( IT.name.wrapper(CA.getaxes(cai), n_dims)::IT.name.wrapper end function StaticComponentArrayInterpreter( - axes::NTuple{M,<:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N} + axes::NTuple{A,<:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {A,N} axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...) StaticComponentArrayInterpreter{axes_ext}() end -# with ComponentArrays in the last dimensions (e.g. columnnames of a matrix) -function ComponentArrayInterpreter( - n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where {N} - ComponentArrayInterpreter(n_dims, CA.getaxes(ca)) -end -function ComponentArrayInterpreter( - n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where {N} - ComponentArrayInterpreter(n_dims, CA.getaxes(cai)) -end -function ComponentArrayInterpreter( - n_dims::NTuple{N,<:Integer}, axes::NTuple{M,<:CA.AbstractAxis}) where {N,M} - axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...) - ComponentArrayInterpreter(axes_ext) -end function stack_ca_int( ::Val{n_dims}, cai::IT) where {IT<:AbstractComponentArrayInterpreter,n_dims} diff --git a/src/DoubleMM/DoubleMM.jl b/src/DoubleMM/DoubleMM.jl index e0a3d4c..6b39306 100644 --- a/src/DoubleMM/DoubleMM.jl +++ b/src/DoubleMM/DoubleMM.jl @@ -14,7 +14,7 @@ using MLDataDevices import GPUArraysCore # used in conditional breakpoints import StableRNGs -export f_doubleMM, xP_S1, xP_S2 +export f_doubleMM, f_doubleMM_sites, xP_S1, xP_S2 include("f_doubleMM.jl") diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl index 01b58b2..2d8afbb 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -5,6 +5,9 @@ const θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2) const θall = vcat(θP, θM) const θP_nor0 = θP[(:K2,)] +θP_nor0_K1 = θM[(:K1,)] +θM_nor0_K1 = vcat(θM[(:r1,)], θP[(:K2,)]) + const xP_S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.2, 0.1] const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0] @@ -18,49 +21,143 @@ int_xP1 = ComponentArrayInterpreter(CA.ComponentVector(S1 = xP_S1, S2 = xP_S2)) const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM))) -function f_doubleMM(θ::AbstractVector, x; intθ1) +""" + f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET + +Example process based model (PBM) predicts a double-monod constrained rate +for different substrate concentration vectors, `x.S1`, and `x.S2` for a single site. +θc is a ComponentVector with scalar parameters as components: `r0`, `r1`, `K1`, and `K2` + +It predicts a rate for each entry in concentrations: +`y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)`. 
+ +It is defined as +```julia +function f_doubleMM(θc::ComponentVector{ET}, x) where ET + # extract parameters not depending on order, i.e whether they are in θP or θM + # r0 = θc[:r0] + (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par + getdata(θc[par])::ET + end + y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) + return (y) +end +``` +""" +function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET # extract parameters not depending on order, i.e whether they are in θP or θM - y = GPUArraysCore.allowscalar() do - θc = intθ1(θ) + GPUArraysCore.allowscalar() do # index to scalar parameter in parameter vector + #θc = intθ1(θ) #using ComponentArrays: ComponentArrays as CA #r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)] # does not work on Zygote+GPU (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par # vector will be repeated when broadcasted by a matrix - CA.getdata(θc[par]) + CA.getdata(θc[par])::ET end # r0 = θc[:r0] # r1 = θc[:r1] # K1 = θc[:K1] # K2 = θc[:K2] y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) + return (y) end - return (y) end -function f_doubleMM( - θ::AbstractMatrix{T}, x; intθ::HVI.AbstractComponentArrayInterpreter) where T - # provide θ for n_row sites - # provide x.S1 as Matrix n_site x n_obs - # extract parameters not depending on order, i.e whether they are in θP or θM - θc = intθ(θ) - @assert size(x.S1, 1) == size(θ, 1) # same number of sites - @assert size(x.S1) == size(x.S2) # same number of observations - #@assert length(x.s2 == n_obs) - # problems on AD on GPU with indexing CA may be related to printing result, use ";" - VT = typeof(θ[:,1]) # workaround for non-stable Symbol-indexing CAMatrix - #VT = first(Base.return_types(getindex, Tuple{typeof(θ),typeof(Colon()),typeof(1)})) +""" + f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix) + +Example process based model (PBM) that predicts for a batch of sites. + +Arguments +- `θc`: parameters with one row per site and symbolic column index +- `xPc`: model drivers with one column per site, and symbolic row index + +Returns a matrix `(n_obs x n_site)` of predictions. 
+ +```julia +function f_doubleMM_sites(θc::ComponentMatrix, xPc::ComponentMatrix) + # extract several covariates from xP + ST = typeof(getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing + S1 = (getdata(xPc[:S1,:])::ST) + S2 = (getdata(xPc[:S2,:])::ST) + # + # extract the parameters as vectors that are row-repeated into a matrix + n_obs = size(S1, 1) + VT = typeof(getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par - # vector will be repeated when broadcasted by a matrix - CA.getdata(θc[:, par]) ::VT + p1 = getdata(θc[:, par]) ::VT + repeat(p1', n_obs) # matrix: same for each concentration row in S1 end - # r0 = CA.getdata(θc[:,:r0]) # vector will be repeated when broadcasted by a matrix - # r1 = CA.getdata(θc[:,:r1]) - # K1 = CA.getdata(θc[:,:K1]) - # K2 = CA.getdata(θc[:,:K2]) - y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) - return (y) + # + # each variable is a matrix (n_obs x n_site) + r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2) +end +``` +""" +function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix) + # extract several covariates from xP + ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing + S1 = (CA.getdata(xPc[:S1,:])::ST) + S2 = (CA.getdata(xPc[:S2,:])::ST) + # + # extract the parameters as vectors that are row-repeated into a matrix + n_obs = size(S1, 1) + VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing + (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par + p1 = CA.getdata(θc[:, par]) ::VT + repeat(p1', n_obs) # matrix: same for each concentration row in S1 + end + # + # each variable is a matrix (n_obs x n_site) + r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2) end +# function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix) +# # extract the parameters as vectors +# VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing +# (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par +# CA.getdata(θc[:, par]) ::VT +# end +# # +# # extract several covariates from xP +# # S1 = (xPc[:S1,:])' # transform site-last -> site-first dimension +# # S2 = (xPc[:S2,:])' +# #Main.@infiltrate_main + +# ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing +# S1 = (CA.getdata(xPc[:S1,:])::ST)' # transform site-last -> site-first dimension +# S2 = (CA.getdata(xPc[:S2,:])::ST)' +# # +# y = r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2) +# return (CA.getdata(y)') # transform site-first -> site-last dimension +# end + + + +# function f_doubleMM( +# θ::AbstractMatrix{T}, x; intθ::HVI.AbstractComponentArrayInterpreter) where T +# # provide θ for n_row sites +# # provide x.S1 as Matrix n_site x n_obs +# # extract parameters not depending on order, i.e whether they are in θP or θM +# θc = intθ(θ) +# @assert size(x.S1, 1) == size(θ, 1) # same number of sites +# @assert size(x.S1) == size(x.S2) # same number of observations +# #@assert length(x.s2 == n_obs) +# # problems on AD on GPU with indexing CA may be related to printing result, use ";" +# VT = typeof(θ[:,1]) # workaround for non-stable Symbol-indexing CAMatrix +# #VT = first(Base.return_types(getindex, Tuple{typeof(θ),typeof(Colon()),typeof(1)})) +# (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par +# # vector will be repeated when broadcasted by a matrix +# CA.getdata(θc[:, par]) ::VT +# end +# # r0 = CA.getdata(θc[:,:r0]) # vector will be repeated when broadcasted by a matrix +# # r1 = 
CA.getdata(θc[:,:r1]) +# # K1 = CA.getdata(θc[:,:K1]) +# # K2 = CA.getdata(θc[:,:K2]) +# y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) +# return (y) +# end + # function f_doubleMM(θ::AbstractMatrix, x::NamedTuple, θpos::NamedTuple) # # provide θ for n_row sites # # provide x.S1 as Matrix n_site x n_obs @@ -87,6 +184,10 @@ function HVI.get_hybridproblem_par_templates( ::DoubleMMCase; scenario::Val{scen}) where {scen} if (:omit_r0 ∈ scen) #return ((; θP = θP_nor0, θM, θf = θP[(:K2r)])) + if (:K1global ∈ scen) + # scenario of K1 global but K2 site-dependent to inspect correlations^ + return ((; θP = θP_nor0_K1, θM = θM_nor0_K1)) + end return ((; θP = θP_nor0, θM)) end #(; θP, θM, θf = eltype(θP)[]) @@ -135,6 +236,9 @@ end function HVI.get_hybridproblem_pbmpar_covars( ::DoubleMMCase; scenario::Val{scen}) where {scen} if (:covarK2 ∈ scen) + if (:K1global ∈ scen) + return (:K1,) + end return (:K2,) end () @@ -166,59 +270,54 @@ end # (; n_covar, n_batch, n_θM, n_θP) # end -# function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = (), -# gdev = :f_on_gpu ∈ scenario ? gpu_device() : identity, -# ) -# #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers -# par_templates = get_hybridproblem_par_templates(prob; scenario) -# intθ, θFix = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall) -# let θFix = gdev(θFix), intθ = get_concrete(intθ) -# function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP) -# pred_sites = map_f_each_site(f_doubleMM, θMs, θP, θFix, xP, intθ) -# pred_global = eltype(pred_sites)[] -# return pred_global, pred_sites -# end -# end -# end - # defining the PBmodel as a closure with let leads to problems of JLD2 reloading # Define all the variables additional to the ones passed curing the call by # a dedicated Closure object and define the PBmodel as a callable -struct DoubleMMCaller{CLT} - cl::CLT -end +# struct DoubleMMCaller{CLT} +# cl::CLT +# end -function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario, kwargs...) +function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; use_all_sites=false, scenario::Val{scen}) where {scen} # θall defined in this module above - cl = HVI.PBmodelClosure(prob; scenario, θall, int_xP1, kwargs...) - return DoubleMMCaller{typeof(cl)}(cl) + # TODO check and test for population or sites, currently return only site specific + pt = get_hybridproblem_par_templates(prob; scenario) + keys_fixed = Tuple(k for k in setdiff(keys(θall), (keys(pt.θP)..., keys(pt.θM)...))) + θFix = isempty(keys_fixed) ? CA.ComponentVector{eltype(θall)}() : θall[keys_fixed] + xPvec = int_xP1(vcat(xP_S1, xP_S2)) + if (:useSitePBM ∈ scen) + PBMSiteApplicator(f_doubleMM; pt.θP, pt.θM, θFix, xPvec) + else + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + n_site_batch = use_all_sites ? n_site : n_batch + PBMPopulationApplicator(f_doubleMM_sites, n_site_batch; pt.θP, pt.θM, θFix, xPvec) + end end -function(caller::DoubleMMCaller)(θP::AbstractVector, θMs::AbstractMatrix, xP) - cl = caller.cl - @assert size(xP, 2) == cl.n_site_batch - @assert size(θMs, 1) == cl.n_site_batch - # # convert vector of tuples to tuple of matricesByRows - # # need to supply xP as vectorOfTuples to work with DataLoader - # # k = first(keys(xP[1])) - # xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k - # #stack(map(r -> r[k], xP))' - # stack(map(r -> r[k], xP); dims = 1) - # end)...) 
- #xPM = map(transpose, xPM1) - #xPc = int_xPb(CA.getdata(xP)) - #xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote - # make sure the same order of columns as in intθ - # reshape big matrix into NamedTuple of drivers S1 and S2 - # for broadcasting need sites in rows - #xPM = map(p -> CA.getdata(xP[p,:])', pos_xP) - xPM = map(p -> CA.getdata(xP)'[:, p], cl.pos_xP) - θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? cl.θFix_dev : cl.θFix - θ = hcat(CA.getdata(θP[cl.isP]), CA.getdata(θMs), θFixd) - pred_sites = f_doubleMM(θ, xPM; cl.intθ)' - pred_global = eltype(pred_sites)[] - return pred_global, pred_sites -end +# function(caller::DoubleMMCaller)(θP::AbstractVector, θMs::AbstractMatrix, xP) +# cl = caller.cl +# @assert size(xP, 2) == cl.n_site_batch +# @assert size(θMs, 1) == cl.n_site_batch +# # # convert vector of tuples to tuple of matricesByRows +# # # need to supply xP as vectorOfTuples to work with DataLoader +# # # k = first(keys(xP[1])) +# # xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k +# # #stack(map(r -> r[k], xP))' +# # stack(map(r -> r[k], xP); dims = 1) +# # end)...) +# #xPM = map(transpose, xPM1) +# #xPc = int_xPb(CA.getdata(xP)) +# #xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote +# # make sure the same order of columns as in intθ +# # reshape big matrix into NamedTuple of drivers S1 and S2 +# # for broadcasting need sites in rows +# #xPM = map(p -> CA.getdata(xP[p,:])', pos_xP)get_hybridproblem_PBmodel +# xPM = map(p -> CA.getdata(xP)'[:, p], cl.pos_xP) +# θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? cl.θFix_dev : cl.θFix +# θ = hcat(CA.getdata(θP[cl.isP]), CA.getdata(θMs), θFixd) +# pred_sites = f_doubleMM(θ, xPM; cl.intθ)' +# pred_global = eltype(pred_sites)[] +# return pred_global, pred_sites +# end function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::Val) neg_logden_indep_normal @@ -266,7 +365,7 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase; int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,)) # normalize to be distributed around the prescribed true values θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1))) - f = get_hybridproblem_PBmodel(prob; scenario, gdev = identity, use_all_sites = true) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true) #xP = fill((; S1 = xP_S1, S2 = xP_S2), n_site) int_xP_sites = ComponentArrayInterpreter(int_xP1, (n_site,)) xP = int_xP_sites(vcat(repeat(xP_S1, 1, n_site), repeat(xP_S2, 1, n_site))) @@ -292,3 +391,15 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase; ) end +function HVI.get_hybridproblem_cor_ends(prob::DoubleMMCase; scenario::Val{scen}) where {scen} + pt = get_hybridproblem_par_templates(prob; scenario) + if (:neglect_cor ∈ scen) + # one block for each parameter, i.e. neglect all correlations + (P = 1:length(pt.θP), M = 1:length(pt.θM)) + else + # single big blocks + (P = [length(pt.θP)], M = [length(pt.θM)]) + end +end + + diff --git a/src/HybridProblem.jl b/src/HybridProblem.jl index 98d4fc9..c87fdbc 100644 --- a/src/HybridProblem.jl +++ b/src/HybridProblem.jl @@ -1,3 +1,28 @@ +""" +Implements [`AbstractHybridProblem`](@ref) by gathering all the parts into +one struct. 
+
+Fields:
+- `θP::ComponentVector`, `θM::ComponentVector`: parameter templates
+- `g::AbstractModelApplicator`, `ϕg::AbstractVector`: ML model and its parameters
+- `ϕunc::ComponentVector`: parameters of the covariance matrix of the approximate posterior
+- `f_batch`: process-based model predicting for `n_batch` sites
+- `f_allsites`: process-based model predicting for `n_site` sites
+- `priors::AbstractDict`: prior distributions for all PBM parameters on constrained scale
+- `py`: likelihood function
+- `transM::Stacked`, `transP::Stacked`: bijectors transforming from unconstrained to
+  constrained scale for site-specific and global parameters respectively.
+- `train_dataloader::MLUtils.DataLoader`: providing a Tuple of matrices
+  `(xM, xP, y_o, y_unc, i_sites)`: covariates, model drivers, observations,
+  observation uncertainties, and indices of the provided sites.
+- `n_covar::Int`, `n_site::Int`, `n_batch::Int`: number of covariates,
+  number of sites, and number of sites within one batch
+- `cor_ends::NamedTuple`: block structure in correlations,
+  defaults to `(P = [length(θP)], M = [length(θM)])`
+- `pbm_covars::NTuple{N,Symbol}`: names of global parameters used as covariates
+  in the ML model, defaults to `()`, i.e. no covariates fed into the ML model
+
+"""
struct HybridProblem <: AbstractHybridProblem
    θP::CA.ComponentVector
    θM::CA.ComponentVector
@@ -20,7 +45,6 @@ struct HybridProblem <: AbstractHybridProblem
    function HybridProblem(
            θP::CA.ComponentVector, θM::CA.ComponentVector,
            g::AbstractModelApplicator, ϕg::AbstractVector,
-            ϕunc::CA.ComponentVector,
            f_batch, f_allsites,
            priors::AbstractDict,
@@ -33,7 +57,8 @@ struct HybridProblem <: AbstractHybridProblem
            n_site::Int, n_batch::Int,
            cor_ends::NamedTuple = (P = [length(θP)], M = [length(θM)]),
-            pbm_covars::NTuple{N,Symbol} = ()
+            pbm_covars::NTuple{N,Symbol} = (),
+            ϕunc::CA.ComponentVector = init_hybrid_ϕunc(cor_ends, zero(eltype(θM))),
    ) where N
        new(
            θP, θM, f_batch, f_allsites, g, ϕg, ϕunc, priors, py, transM, transP, cor_ends,
@@ -41,57 +66,73 @@ struct HybridProblem <: AbstractHybridProblem
    end
end
-function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector,
-    # note no ϕg argument and g_chain unconstrained
-    g_chain, f_batch,
-    args...; rng = Random.default_rng(), kwargs...)
-    # dispatches on type of g_chain
-    g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM))
-    HybridProblem(θP, θM, g, ϕg, f_batch, args...; kwargs...)
-end
+# function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector,
+#     # note no ϕg argument and g_chain unconstrained
+#     g_chain, f_batch,
+#     args...; rng = Random.default_rng(), kwargs...)
+#     # dispatches on type of g_chain
+#     g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM))
+#     HybridProblem(θP, θM, g, ϕg, f_batch, args...; kwargs...)
+# end -function HybridProblem(prob::AbstractHybridProblem; scenario = ()) - (; θP, θM) = get_hybridproblem_par_templates(prob; scenario) - g, ϕg = get_hybridproblem_MLapplicator(prob; scenario) - ϕunc = get_hybridproblem_ϕunc(prob; scenario) - f_batch = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) - f_allsites = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true) - py = get_hybridproblem_neg_logden_obs(prob; scenario) - (; transP, transM) = get_hybridproblem_transforms(prob; scenario) - train_dataloader = get_hybridproblem_train_dataloader(prob; scenario) - cor_ends = get_hybridproblem_cor_ends(prob; scenario) - pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) - priors = get_hybridproblem_priors(prob; scenario) - n_covar = get_hybridproblem_n_covar(prob; scenario) - n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) - HybridProblem(θP, θM, g, ϕg, ϕunc, f_batch, f_allsites, priors, py, transM, transP, train_dataloader, - n_covar, n_site, n_batch, cor_ends, pbm_covars) -end +""" + HybridProblem(prob::AbstractHybridProblem; scenario = () -function update(prob::HybridProblem; - θP::CA.ComponentVector = prob.θP, - θM::CA.ComponentVector = prob.θM, - g::AbstractModelApplicator = prob.g, - ϕg::AbstractVector = prob.ϕg, - ϕunc::CA.ComponentVector = prob.ϕunc, - f_batch = prob.f_batch, - f_allsites = prob.f_allsites, - priors::AbstractDict = prob.priors, - py = prob.py, - # transM::Union{Function, Bijectors.Transform} = prob.transM, - # transP::Union{Function, Bijectors.Transform} = prob.transP, - transM = prob.transM, - transP = prob.transP, - cor_ends::NamedTuple = prob.cor_ends, - pbm_covars::NTuple{N,Symbol} = prob.pbm_covars, - train_dataloader::MLUtils.DataLoader = prob.train_dataloader, - n_covar::Integer = prob.n_covar, - n_site::Integer = prob.n_site, - n_batch::Integer = prob.n_batch, -) where N - HybridProblem(θP, θM, g, ϕg, ϕunc, f_batch, f_allsites, priors, py, transM, transP, - train_dataloader, n_covar, n_site, n_batch, cor_ends, pbm_covars) -end +Gather all information from another `AbstractHybridProblem` with possible +updating of some of the entries. +""" +function HybridProblem(prob::AbstractHybridProblem; scenario = (), + θP = get_hybridproblem_par_templates(prob; scenario).θP, + θM = get_hybridproblem_par_templates(prob; scenario).θM, + g = get_hybridproblem_MLapplicator(prob; scenario)[1], + ϕg = get_hybridproblem_MLapplicator(prob; scenario)[2], + f_batch = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false), + f_allsites = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true), + priors = get_hybridproblem_priors(prob; scenario), + py = get_hybridproblem_neg_logden_obs(prob; scenario), + transP = get_hybridproblem_transforms(prob; scenario).transP, + transM = get_hybridproblem_transforms(prob; scenario).transM, + train_dataloader = get_hybridproblem_train_dataloader(prob; scenario), + n_covar = get_hybridproblem_n_covar(prob; scenario), + n_site = get_hybridproblem_n_site_and_batch(prob; scenario)[1], + n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)[2], + cor_ends = get_hybridproblem_cor_ends(prob; scenario), + pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario), + ϕunc = get_hybridproblem_ϕunc(prob; scenario), + ) + HybridProblem(θP, θM, g, ϕg, f_batch, f_allsites, priors, py, transM, transP, train_dataloader, + n_covar, n_site, n_batch, cor_ends, pbm_covars, ϕunc) +end + +# """ +# update(prob::HybridProblem; ...) + +# Create a copy of prob, with some parts replaced. 
+# """ +# function update(prob::HybridProblem; +# θP::CA.ComponentVector = prob.θP, +# θM::CA.ComponentVector = prob.θM, +# g::AbstractModelApplicator = prob.g, +# ϕg::AbstractVector = prob.ϕg, +# ϕunc::CA.ComponentVector = prob.ϕunc, +# f_batch = prob.f_batch, +# f_allsites = prob.f_allsites, +# priors::AbstractDict = prob.priors, +# py = prob.py, +# # transM::Union{Function, Bijectors.Transform} = prob.transM, +# # transP::Union{Function, Bijectors.Transform} = prob.transP, +# transM = prob.transM, +# transP = prob.transP, +# cor_ends::NamedTuple = prob.cor_ends, +# pbm_covars::NTuple{N,Symbol} = prob.pbm_covars, +# train_dataloader::MLUtils.DataLoader = prob.train_dataloader, +# n_covar::Integer = prob.n_covar, +# n_site::Integer = prob.n_site, +# n_batch::Integer = prob.n_batch, +# ) where N +# HybridProblem(θP, θM, g, ϕg, f_batch, f_allsites, priors, py, transM, transP, +# train_dataloader, n_covar, n_site, n_batch, cor_ends, pbm_covars, ϕunc) +# end function get_hybridproblem_par_templates(prob::HybridProblem; scenario = ()) (; θP = prob.θP, θM = prob.θM) diff --git a/src/HybridSolver.jl b/src/HybridSolver.jl index f33e632..dbf2b76 100644 --- a/src/HybridSolver.jl +++ b/src/HybridSolver.jl @@ -8,8 +8,7 @@ HybridPointSolver(; alg) = HybridPointSolver(alg) function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolver; scenario, rng=Random.default_rng(), - gdev=:use_gpu ∈ _val_value(scenario) ? gpu_device() : identity, - cdev=gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity, + gdevs = get_gdev_MP(scenario), is_inferred::Val{is_infer} = Val(false), kwargs... ) where is_infer @@ -22,6 +21,7 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve #ϕ0_cpu = vcat(ϕg0, par_templates.θP .* FT(0.9)) # slightly disturb θP_true ϕ0_cpu = vcat(ϕg0, apply_preserve_axes(inverse(transP), par_templates.θP)) train_loader = get_hybridproblem_train_dataloader(prob; scenario) + gdev = gdevs.gdev_M if gdev isa MLDataDevices.AbstractGPUDevice ϕ0_dev = gdev(ϕ0_cpu) g_dev = gdev(g) @@ -37,7 +37,7 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) #intP = ComponentArrayInterpreter(par_templates.θP) loss_gf = get_loss_gf(g_dev, transM, transP, f, y_global_o, intϕ; - cdev, pbm_covars, n_site_batch=n_batch) + cdev=infer_cdev(gdevs), pbm_covars, n_site_batch=n_batch) # call loss function once l1 = is_infer ? Test.@inferred(loss_gf(ϕ0_dev, first(train_loader_dev)...))[1] : @@ -56,7 +56,7 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve res = Optimization.solve(optprob, solver.alg; kwargs...) ϕ = intϕ(res.u) θP = cpu_ca(apply_preserve_axes(transP, cpu_ca(ϕ).ϕP)) - probo = update(prob; ϕg=cpu_ca(ϕ).ϕg, θP) + probo = HybridProblem(prob; ϕg=cpu_ca(ϕ).ϕg, θP) (; ϕ, resopt=res, probo) end @@ -68,17 +68,42 @@ end function HybridPosteriorSolver(; alg, n_MC=12, n_MC_cap=n_MC) HybridPosteriorSolver(alg, n_MC, n_MC_cap) end -function update(solver::HybridPosteriorSolver; +function HybridPosteriorSolver(solver::HybridPosteriorSolver; alg=solver.alg, n_MC=solver.n_MC, n_MC_cap=n_MC) HybridPosteriorSolver(alg, n_MC, n_MC_cap) end +""" + solve(prob::AbstractHybridProblem, solver::HybridPosteriorSolver; ...) + +Perform the inversion of HVI Problem. + +Optional keyword arguments +- `scenario`: Scenario to query prob, defaults to `Val(())`. +- `rng`: Random generator, defaults to `Random.default_rng()`. 
+
+- `gdevs`: `NamedTuple` `(;gdev_M, gdev_P)` of functions that move computation and
+  data of the ML model and of the PBM, respectively, to the gpu (e.g. `gpu_device()`)
+  or to the cpu (`identity`).
+  Defaults to [`get_gdev_MP`](@ref)`(scenario)`.
+- `θmean_quant`: defaults to `0.0`; deprecated
+- `is_inferred`: set to `Val(true)` to activate type stability checks
+
+Returns a `NamedTuple` of
+- `probo`: a copy of the HybridProblem, with updated optimized parameters
+- `interpreters`: TODO
+- `ϕ`: the optimized HVI parameters: a `ComponentVector` with entries
+  - `μP`: `ComponentVector` of the mean global PBM parameters at unconstrained scale
+  - `ϕg`: the ML model parameter vector
+  - `unc`: `ComponentVector` of further uncertainty parameters
+- `θP`: `ComponentVector` of the mean global PBM parameters at constrained scale
+- `resopt`: the structure returned by `Optimization.solve`. It can contain
+  more information on convergence.
+"""
function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorSolver;
-        scenario::Val{scen}, rng=Random.default_rng(),
-        gdev=:use_gpu ∈ _val_value(scenario) ? gpu_device() : identity,
-        cdev=gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity,
+        scenario::Val{scen}=Val(()), rng=Random.default_rng(),
+        gdevs = get_gdev_MP(scenario),
        θmean_quant=0.0,
        is_inferred::Val{is_infer} = Val(false),
        kwargs...
@@ -99,10 +124,10 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
    transMs = StackedArray(transM, n_batch)
    # train_loader = get_hybridproblem_train_dataloader(prob; scenario)
-    if gdev isa MLDataDevices.AbstractGPUDevice
-        ϕ0_dev = gdev(ϕ)
-        g_dev = gdev(g) # zygote fails if gdev is a CPUDevice, although should be non-op
-        train_loader_dev = gdev_hybridproblem_dataloader(train_loader; scenario, gdev)
+    if gdevs.gdev_M isa MLDataDevices.AbstractGPUDevice
+        ϕ0_dev = gdevs.gdev_M(ϕ)
+        g_dev = gdevs.gdev_M(g) # zygote fails if gdev is a CPUDevice, although should be non-op
+        train_loader_dev = gdev_hybridproblem_dataloader(train_loader; gdevs)
    else
        ϕ0_dev = ϕ
        g_dev = g
@@ -113,15 +138,18 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
    priors_θP_mean, priors_θMs_mean = construct_priors_θ_mean(
        prob, ϕ0_dev.ϕg, keys(θM), θP, θmean_quant, g_dev, transM, transP;
-        scenario, get_ca_int_PMs, gdev, cdev, pbm_covars)
+        scenario, get_ca_int_PMs, gdevs, pbm_covars)
    y_global_o = Float32[] # TODO
    loss_elbo = get_loss_elbo(
        g_dev, transP, transMs, f, py, y_global_o;
-        solver.n_MC, solver.n_MC_cap, cor_ends, priors_θP_mean, priors_θMs_mean, cdev,
-        pbm_covars, θP, int_unc, int_μP_ϕg_unc)
+        solver.n_MC, solver.n_MC_cap, cor_ends, priors_θP_mean, priors_θMs_mean,
+        cdev=infer_cdev(gdevs), pbm_covars, θP, int_unc, int_μP_ϕg_unc)
    # test loss function once
-    #Main.@infiltrate_main
+    # tmp = first(train_loader_dev)
+    # using ShareAdd
+    # @usingany Cthulhu
+    # @descend_code_warntype loss_elbo(ϕ0_dev, rng, first(train_loader_dev)...)
    l0 = is_infer ?
        (Test.@inferred loss_elbo(ϕ0_dev, rng, first(train_loader_dev)...)) :
        loss_elbo(ϕ0_dev, rng, first(train_loader_dev)...)
@@ -131,8 +159,8 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
    res = Optimization.solve(optprob, solver.alg; kwargs...)
ϕc = interpreters.μP_ϕg_unc(res.u) θP = cpu_ca(apply_preserve_axes(transP, ϕc.μP)) - probo = update(prob; ϕg=cpu_ca(ϕc).ϕg, θP=θP, ϕunc=cpu_ca(ϕc).unc) - (; ϕ=ϕc, θP, resopt=res, interpreters, probo) + probo = HybridProblem(prob; ϕg=cpu_ca(ϕc).ϕg, θP=θP, ϕunc=cpu_ca(ϕc).unc) + (; probo, interpreters, ϕ=ϕc, θP, resopt=res) end function fit_narrow_normal(θi, prior, θmean_quant) @@ -230,9 +258,11 @@ In order to let mean of θ stay close to initial point parameter estimates construct a prior on mean θ to a Normal around initial prediction. """ function construct_priors_θ_mean(prob, ϕg, keysθM, θP, θmean_quant, g_dev, transM, transP; - scenario::Val{scen}, get_ca_int_PMs, gdev, cdev, pbm_covars) where {scen} + scenario::Val{scen}, get_ca_int_PMs, gdevs, pbm_covars) where {scen} iszero(θmean_quant) ? ([],[]) : begin + gdev=gdevs.gdev_M + #cdev=infer_cdev(gdevs) n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) # all_loader = MLUtils.DataLoader( # get_hybridproblem_train_dataloader(prob; scenario).data, batchsize = n_site) diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index 9ea9b57..7cf9d88 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -21,8 +21,12 @@ using StaticArrays: StaticArrays as SA using Functors using Test: Test # @inferred +export DoubleMM + export extend_stacked_nrow, StackedArray -#export Exp +#public Exp +#julia 1.10 public: https://github.com/JuliaLang/julia/pull/55097 +VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public Exp")) include("bijectors_utils.jl") export AbstractComponentArrayInterpreter, ComponentArrayInterpreter, @@ -36,6 +40,9 @@ export construct_3layer_MLApplicator, select_ml_engine export NullModelApplicator, MagnitudeModelApplicator, NormalScalingModelApplicator include("ModelApplicator.jl") +export AbstractPBMApplicator, NullPBMApplicator, PBMSiteApplicator, PBMPopulationApplicator +include("PBMApplicator.jl") + # export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler # include("GPUDataHandler.jl") @@ -50,11 +57,11 @@ export AbstractHybridProblem, get_hybridproblem_MLapplicator, get_hybridproblem_ get_hybridproblem_cor_ends, get_hybridproblem_priors, get_hybridproblem_pbmpar_covars, -#update, gen_cov_pred, construct_dataloader_from_synthetic, gdev_hybridproblem_dataloader, - setup_PBMpar_interpreter + setup_PBMpar_interpreter, + get_gdev_MP include("AbstractHybridProblem.jl") export AbstractHybridProblemInterpreters, HybridProblemInterpreters, @@ -67,7 +74,8 @@ export HybridProblem export get_quantile_transformed include("HybridProblem.jl") -export map_f_each_site, gf, get_loss_gf +export gf, get_loss_gf +#export map_f_each_site include("gf.jl") export compute_correlated_covars, scale_centered_at @@ -85,7 +93,7 @@ include("logden_normal.jl") export get_ca_starts, get_ca_ends, get_cor_count include("cholesky.jl") -export neg_elbo_gtf, predict_hvi +export neg_elbo_gtf, sample_posterior, apply_process_model, predict_hvi include("elbo.jl") export init_hybrid_params, init_hybrid_ϕunc diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl index 85a1140..943a147 100644 --- a/src/ModelApplicator.jl +++ b/src/ModelApplicator.jl @@ -95,9 +95,7 @@ end MagnitudeModelApplicator(app, y0) Wrapper around AbstractModelApplicator that multiplies the prediction -by the absolute inverse of an initial estimate of the prediction. - -This helps to keep raw predictions and weights in a similar magnitude. +of the wrapped `app` by scalar `y0`. 
""" struct MagnitudeModelApplicator{M,A} <: AbstractModelApplicator app::A @@ -137,18 +135,26 @@ end @functor NormalScalingModelApplicator """ -Fit a Normal distribution to iterators lower and upper. -If `repeat_inner` is given, each fitted distribution is repeated as many times. + NormalScalingModelApplicator(app, lowers, uppers, FT::Type; repeat_inner::Integer = 1) + +Fit a Normal distribution to number iterators `lower` and `upper` and transform +results of the wrapped `app` `AbstractModelApplicator`. +If `repeat_inner` is given, each fitted distribution is repeated as many times +to support independent multivariate normal distribution. + +`FT` is the specific FloatType to use to construct Distributions, +It usually corresponds to the type used in other ML-parts of the model, e.g. `Float32`. """ function NormalScalingModelApplicator( - app::AbstractModelApplicator, lowers::AbstractVector{<:Number}, uppers, ET::Type; repeat_inner::Integer = 1) + app::AbstractModelApplicator, lowers, uppers, FT::Type; + repeat_inner::Integer = 1) pars = map(lowers, uppers) do lower, upper dζ = fit(Normal, @qp_l(lower), @qp_u(upper)) params(dζ) end # use collect to make it an array that works with gpu - μ = repeat(collect(ET, first.(pars)); inner=(repeat_inner,)) - σ = repeat(collect(ET, last.(pars)); inner=(repeat_inner,)) + μ = repeat(collect(FT, first.(pars)); inner=(repeat_inner,)) + σ = repeat(collect(FT, last.(pars)); inner=(repeat_inner,)) NormalScalingModelApplicator(app, μ, σ) end diff --git a/src/PBMApplicator.jl b/src/PBMApplicator.jl new file mode 100644 index 0000000..a040840 --- /dev/null +++ b/src/PBMApplicator.jl @@ -0,0 +1,177 @@ +""" + AbstractPBMApplicator(θP::AbstractVector, θMs::AbstractMatrix, xP::AbstractMatrix) + +Abstraction of applying a process-based model with +global parameters, `x`, site-specific parameters, `θMs` (sites in columns), +and site-specific model drivers, `xP` (sites in columns), +It returns a matrix of predictions sites in columns. + +Specific implementations need to implement function `apply_model(app, θP, θMs, xP)`. +Provided are implementations +- `NullPBMApplicator`: returning its input `θMs` for testing +- `PBMSiteApplicator`: based on a function that computes predictions per site +- `PBMPopulationApplicator`: based on a function that computes predictions for entire population +""" +abstract type AbstractPBMApplicator end + +# function apply_model end # already defined in ModelApplicator.jl for ML model + +function (app::AbstractPBMApplicator)(θP::AbstractVector, θMs::AbstractMatrix, xP::AbstractMatrix) + apply_model(app, θP, θMs, xP) +end + + +""" + NullPBMApplicator() + +Process-Base-Model applicator that returns its θMs inputs. Used for testing. +""" +struct NullPBMApplicator <: AbstractPBMApplicator end + +function apply_model(app::NullPBMApplicator, θP::AbstractVector, θMs::AbstractMatrix, xP) + return CA.getdata(θMs) +end + + +struct PBMSiteApplicator{F, IT, IXT, VFT} <: AbstractPBMApplicator + fθ::F + intθ1::IT + int_xPsite::IXT + θFix::VFT # can be a CuArray instead of a Vector +end + +""" + PBMSiteApplicator(fθ, n_batch; θP, θM, θFix, xPvec) + +Construct AbstractPBMApplicator from process-based model `fθ` that computes predictions +for a single site. +The Applicator combines enclosed `θFix`, with provided `θMs` and `θP` and +constructs a `ComponentVector` that can be indexed by +symbolic parameter names, corresponding to the templates provided during +construction of the applicator. 
+ +## Arguments +- `fθ`: process model, process model `fθ(θc, xP)`, which is agnostic of the partitioning +of parameters. +- `θP`: `ComponentVector` template of global process model parameters +- `θM`: `ComponentVector` template of individual process model parameters +- `θFix`: `ComponentVector` of actual fixed process model parameters +- `xPvec`:`ComponentVector` template of model drivers for a single site +""" +function PBMSiteApplicator(fθ; + θP::CA.ComponentVector, θM::CA.ComponentVector, θFix::CA.ComponentVector, + xPvec::CA.ComponentVector + ) + intθ1 = get_concrete(ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM, θFix)))) + int_xPsite = get_concrete(ComponentArrayInterpreter(xPvec)) + PBMSiteApplicator(fθ, intθ1, int_xPsite, CA.getdata(θFix)) +end + +function apply_model(app::PBMSiteApplicator, θP::AbstractVector, θMs::AbstractMatrix, xP) + function apply_PBMsite(θM, xP1) + if (CA.getdata(θP) isa GPUArraysCore.AbstractGPUArray) && + (!(app.θFix isa GPUArraysCore.AbstractGPUArray) || + !(CA.getdata(θMs) isa GPUArraysCore.AbstractGPUArray)) + error("concatenating GPUarrays with non-gpu arrays θFix or θMs. " * + "May fmap PBMModelapplicators to gdev, " * + "or compute PBMmodel on CPU") + end + θ = vcat(CA.getdata(θP), CA.getdata(θM), app.θFix) + θc = app.intθ1(θ); # show errors without ";" + xPc = app.int_xPsite(xP1); + ans = CA.getdata(app.fθ(θc, xPc)) + ans + end + # mapreduce-hcat is only typestable with init, which needs number of rows + # https://discourse.julialang.org/t/type-instability-of-mapreduce-vs-map-reduce/121136 + # local pred_sites = mapreduce( + # apply_PBMsite, hcat, eachrow(θMs), eachcol(xP); init=Matrix{Float64}(undef,n_obs,0)) + θMs1, it_θMs = if (CA.getdata(θP) isa GPUArraysCore.AbstractGPUArray) + # if working on CuArray, better materialize transpose and use eachcol for contiguous + # avoid eachrow, because it does produce non-strided views which are bad on GPU, + # https://discourse.julialang.org/t/using-view-with-cuarrays/104057/5 + # better compute on CPU or use matrix-version of PBMModel + θMst = copy(CA.getdata(θMs)') + Iterators.peel(eachcol(θMst)); + else + Iterators.peel(eachrow(CA.getdata(θMs))) + end + xP1, it_xP = Iterators.peel(eachcol(CA.getdata(xP))) + obs1 = apply_PBMsite(θMs1, xP1) + local pred_sites = mapreduce( + apply_PBMsite, hcat, it_θMs, it_xP; init=reshape(obs1, :, 1)) + # # special case of mapreduce producing a vector rather than a matrix + # pred_sites = !(pred_sites0 isa AbstractMatrix) ? hcat(pred_sites0) : pred_sites0 + #obs1 = apply_PBMsite(first(eachrow(θMs)), first(eachcol(xP))) + #obs_vecs = map(apply_PBMsite, eachrow(θMs), eachcol(xP)) + #obs_vecs = (apply_PBMsite(θMs1, xP1) for (θMs1, xP1) in zip(eachrow(θMs), eachcol(xP))) + #pred_sites = stack(obs_vecs; dims = 1) + #pred_sites = stack(obs_vecs) # does not work with Zygote + local pred_global = eltype(pred_sites)[] # TODO remove + return pred_global, pred_sites +end + +struct PBMPopulationApplicator{MFT, IPT, IT, IXT, F} <: AbstractPBMApplicator + fθpop::F + θFixm::MFT # may be CuMatrix rather than Matrix + isP::IPT #Matrix{Int} # transferred to CuMatrix? + intθ::IT + int_xP::IXT +end + +# let fmap not descend into isP +# @functor PBMPopulationApplicator (θFixm, ) + +""" + PBMPopulationApplicator(fθpop, n_batch; θP, θM, θFix, xPvec) + +Construct AbstractPBMApplicator from process-based model `fθ` that computes predictions +across sites for a population of size `n_batch`. 
+The applicator combines the enclosed `θFix` with the provided `θMs` and `θP`
+into a `ComponentMatrix` of parameters with one row for each site, which
+can be column-indexed by Symbols.
+
+## Arguments
+- `fθpop`: process model `fθpop(θc, xPc)`, which is agnostic of the partitioning
+  of parameters into fixed, global, and individual.
+  - `θc`: parameters: `ComponentMatrix` (n_site x n_par) with each row a parameter vector
+  - `xPc`: model drivers: `ComponentMatrix` (n_obs x n_site) with each column
+    the drivers for one site
+- `n_batch`: number of individuals, i.e. rows in `θMs`
+- `θP`: `ComponentVector` template of global process model parameters
+- `θM`: `ComponentVector` template of individual process model parameters
+- `θFix`: `ComponentVector` of actual fixed process model parameters
+- `xPvec`: `ComponentVector` template of model drivers for a single site
+"""
+function PBMPopulationApplicator(fθpop, n_batch;
+        θP::CA.ComponentVector, θM::CA.ComponentVector, θFix::CA.ComponentVector,
+        xPvec::CA.ComponentVector
+        )
+    intθvec = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM, θFix)))
+    int_xP_vec = ComponentArrayInterpreter(xPvec)
+    isFix = repeat(axes(θFix, 1)', n_batch)
+    #
+    intθ = get_concrete(ComponentArrayInterpreter((n_batch,), intθvec))
+    int_xP = get_concrete(ComponentArrayInterpreter(int_xP_vec, (n_batch,)))
+    isP = repeat(axes(θP, 1)', n_batch)
+    θFixm = CA.getdata(θFix[isFix])
+    PBMPopulationApplicator(fθpop, θFixm, isP, intθ, int_xP)
+end
+
+function apply_model(app::PBMPopulationApplicator, θP::AbstractVector, θMs::AbstractMatrix, xP)
+    if (CA.getdata(θP) isa GPUArraysCore.AbstractGPUArray) &&
+       (!(app.θFixm isa GPUArraysCore.AbstractGPUArray) ||
+        !(CA.getdata(θMs) isa GPUArraysCore.AbstractGPUArray))
+        error("concatenating GPU arrays with non-GPU arrays θFixm or θMs. " *
+              "Either transfer the PBMPopulationApplicator to gdev, " *
+              "or compute the PBM on the CPU.")
+    end
+    # repeat θP across rows and concatenate with θMs and the fixed parameters
+    local θ = hcat(CA.getdata(θP[app.isP]), CA.getdata(θMs), app.θFixm)
+    local θc = app.intθ(CA.getdata(θ))
+    local xPc = app.int_xP(CA.getdata(xP))
+    local pred_sites = app.fθpop(θc, xPc)
+    local pred_global = eltype(pred_sites)[] # TODO remove
+    return pred_global, pred_sites
+end
+
diff --git a/src/bijectors_utils.jl b/src/bijectors_utils.jl
index efc0c96..9061ebb 100644
--- a/src/bijectors_utils.jl
+++ b/src/bijectors_utils.jl
@@ -1,5 +1,12 @@
 #------------------- Exp
+"""
+    Exp()
+
+A bijector that applies the broadcasted exponential function, i.e. `exp.(x)`.
+It is equivalent to `elementwise(exp)`, but works better with automatic
+differentiation on the GPU.
+"""
 struct Exp <: Bijector end
diff --git a/src/cholesky.jl b/src/cholesky.jl
index 14af734..24fd444 100644
--- a/src/cholesky.jl
+++ b/src/cholesky.jl
@@ -59,9 +59,10 @@ end
 # """
 # Convert vector v columnwise entries of upper diagonal matrix to UnitUppterTriangular
-
+#
 # Avoid using this repeatedly on GPU arrays, because it only works on CPU (scalar indexing).
-# There is a fallback that pulls `v` to the CPU, applies, and pushes back to GPU.
+#
+# For v isa CuVector, see HybridVariationalInferenceCUDAExt
+# """
 function _vec2uutri(
         v::AbstractVector{T}; n=invsumn(length(v)) + one(T), diag=one(T)) where {T}
@@ -138,7 +139,7 @@ end
 #function uutri2vec(X::CUDA.CuMatrix{T}; kwargs...) where {T}

 """
-Takes a vector of entries of a lower UnitUpperTriangular matrix
+Takes a vector of parameters for a UnitUpperTriangular matrix
 and transforms it to an UpperTriangular that satisfies
 diag(U' * U) = 1.
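+
+An illustrative sketch (the values are arbitrary):
+
+    v = [0.1f0, 0.2f0, 0.3f0]        # 1+2 = 3 entries above the diagonal of a 3x3 factor
+    UC = transformU_cholesky1(v)     # UpperTriangular with diag(UC' * UC) == 1
+    Σ = UC' * UC                     # implied correlation matrix with unit diagonal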
diff --git a/src/elbo.jl b/src/elbo.jl
index 8bd4708..74520e2 100644
--- a/src/elbo.jl
+++ b/src/elbo.jl
@@ -97,6 +97,11 @@ function neg_elbo_ζtf(ζsP, ζsMs, σ, f, py, xP, y_ob, y_unc;
         y_pred_global, y_pred_i = f(θP, θMs, xP)
         # TODO nLogDen prior on \theta
         #nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, y_unc)
+        # Main.@infiltrate_main
+        # Test.@inferred( f(θP, θMs, xP) )
+        # using ShareAdd
+        # @usingany Cthulhu
+        # @descend_code_warntype f(θP, θMs, xP)
         nLy1 = py(y_ob, y_pred_i, y_unc)
         nLy1 - logjac
     end
@@ -113,7 +118,11 @@ function neg_elbo_ζtf(ζsP, ζsMs, σ, f, py, xP, y_ob, y_unc;
     # logdet_jacT2 = -sum_log_σ # log Prod(1/σ_i) = -sum log σ_i
     logdetΣ = 2 * sum(log.(σ))
     n_θ = size(ζsP, 1) + prod(size(ζsMs)[1:2])
-    @assert length(σ) == n_θ
+    if length(σ) != n_θ
+        error("expected length(σ) == n_θ = $(n_θ), but got length(σ) = $(length(σ))")
+        #Main.@infiltrate_main
+    end
+    #@assert length(σ) == n_θ
     entropy_ζ = entropy_MvNormal(n_θ, logdetΣ) # defined in logden_normal
     # if i_sites[1] == 1
     #     #Main.@infiltrate_main
@@ -124,26 +133,8 @@ function neg_elbo_ζtf(ζsP, ζsMs, σ, f, py, xP, y_ob, y_unc;
     nLy, entropy_ζ
 end

-() -> begin
-    nLy = reduce(
-        +, map(eachcol(ζs_cpu[:, 1:n_MC])) do ζi
-            θ_i, y_pred_i, logjac = apply_f_trans(ζi, xP, f, transPMs, interpreters.PMs)
-            # TODO nLogDen prior on \theta
-            #nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, y_unc)
-            nLy1 = py(y_ob, y_pred_i, y_unc)
-            nLy1 - logjac
-        end) / n_MC
-end
-
-() -> begin
-    # using UnicodePlots
-    histogram(nLys)
-end
-
 """
-    predict_hvi([rng], prob::AbstractHybridProblem [,xM, xP]; scenario, ...)
-    predict_hvi(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix;
-        get_transPMs, get_ca_int_PMs, n_sample_pred=200, gdev = identity)
+    predict_hvi([rng], prob::AbstractHybridProblem; scenario, kwargs...)

 Prediction function for hybrid variational inference parameter model.

@@ -155,8 +146,8 @@ Prediction function for hybrid variational inference parameter model.
   access parts of it, e.g. `xP[:S1,...]`.

 ## Keyword arguments
-- scenario
-- n_sample_pred
+- `scenario`
+- `n_sample_pred`

 Returns an NamedTuple `(; y, θsP, θsMs, entropy_ζ)` with entries
 - `y`: Array `(n_obs, n_site, n_sample_pred)` of model predictions.
@@ -167,27 +158,71 @@ Returns an NamedTuple `(; y, θsP, θsMs, entropy_ζ)` with entries
 - `entropy_ζ`: The entropy of the log-determinant of the transformation of
   the set of model parameters, which is involved in uncertainty quantification.
 """
-function predict_hvi(rng, prob::AbstractHybridProblem; scenario, kwargs...)
+function predict_hvi(rng, prob::AbstractHybridProblem; scenario=Val(()),
+        gdevs = get_gdev_MP(scenario),
+        kwargs...
+        )
     dl = get_hybridproblem_train_dataloader(prob; scenario)
-    dl_dev = gdev_hybridproblem_dataloader(dl; scenario)
-    # predict for all sites
+    dl_dev = gdev_hybridproblem_dataloader(dl; gdevs)
     xM, xP = dl_dev.data[1:2]
-    predict_hvi(rng, prob, xM, xP; scenario, kwargs...)
+    (; θsP, θsMs, entropy_ζ) = sample_posterior(rng, prob, xM; scenario, gdevs, kwargs...)
+    #
+    n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
+    n_site_pred = size(θsMs, 1)
+    is_predict_batch = (n_site_pred == n_batch)
+    @assert size(xP, 2) == n_site_pred
+    f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=!is_predict_batch)
+    y = apply_process_model(θsP, θsMs, f, xP)
+    (; y, θsP, θsMs, entropy_ζ)
+end
+
+"""
+    sample_posterior(rng, prob, [xM::AbstractMatrix]; scenario=Val(()), kwargs...)
+
+Sample the posterior parameter distribution of a
+hybrid variational inference problem.
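+
+A typical call (sketch; `prob` denotes any `AbstractHybridProblem`):
+
+    (; θsP, θsMs, entropy_ζ) = sample_posterior(rng, prob; n_sample_pred = 100)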
+
+## Arguments
+- `rng`: random number generator
+- `prob`: the AbstractHybridProblem from which to sample
+- `xM`: covariates for the machine learning model (ML): Matrix `(n_covar x n_site_pred)`.
+  Defaults to all sites in `get_hybridproblem_train_dataloader(prob; scenario)`.
+
+## Keyword arguments
+- `scenario`: scenario used to query `prob` and to set the defaults of the GPU devices.
+- `n_sample_pred`: number of samples to draw, defaults to 200
+- `gdevs`: NamedTuple `(; gdev_M, gdev_P)`: GPU devices for the machine learning model
+  and the parameter transformation; defaults to [`get_gdev_MP`](@ref)`(scenario)`.
+- `is_inferred`: set to `Val(true)` to activate a type stability check of the transformation
+
+Returns a NamedTuple `(; θsP, θsMs, entropy_ζ)` with entries
+- `θsP`: ComponentArray `(n_θP, n_sample_pred)` of PBM parameters
+  that are kept constant across sites.
+- `θsMs`: ComponentArray `(n_site, n_θM, n_sample_pred)` of PBM parameters
+  that vary by site.
+- `entropy_ζ`: The entropy of the log-determinant of the transformation of
+  the set of model parameters, which is involved in uncertainty quantification.
+"""
+function sample_posterior(rng, prob::AbstractHybridProblem; scenario=Val(()),
+        gdevs = get_gdev_MP(scenario),
+        kwargs...)
+    dl = get_hybridproblem_train_dataloader(prob; scenario)
+    dl_dev = gdev_hybridproblem_dataloader(dl; gdevs)
+    xM = dl_dev.data[1]
+    sample_posterior(rng, prob, xM; scenario, gdevs, kwargs...)
 end

-function predict_hvi(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP;
-        scenario,
+
+
+function sample_posterior(rng, prob::AbstractHybridProblem, xM::AbstractMatrix;
+        scenario=Val(()),
         n_sample_pred=200,
-        gdev=:use_gpu ∈ _val_value(scenario) ? gpu_device() : identity,
-        cdev=!(gdev isa MLDataDevices.AbstractGPUDevice) ? identity :
-             (:f_on_gpu ∈ _val_value(scenario) ? identity : cpu_device()),
+        gdevs = get_gdev_MP(scenario),
         kwargs...
 )
     n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
-    is_predict_batch = (n_batch == length(xP))
+    is_predict_batch = (n_batch == size(xM, 2))
     n_site_pred = is_predict_batch ? n_batch : n_site
-    @assert size(xP, 2) == n_site_pred
     @assert size(xM, 2) == n_site_pred
-    f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=!is_predict_batch)
     par_templates = get_hybridproblem_par_templates(prob; scenario)
     (; θP, θM) = par_templates
     cor_ends = get_hybridproblem_cor_ends(prob; scenario)
@@ -202,70 +237,86 @@ function predict_hvi(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP;
     int_μP_ϕg_unc = interpreters.μP_ϕg_unc
     int_unc = interpreters.unc
     transMs = StackedArray(transM, n_batch)
-    g_dev, ϕ_dev = gdev(g), gdev(ϕ)
-    predict_hvi(rng, g_dev, f, ϕ_dev, xM, xP;
+    g_dev, ϕ_dev = gdevs.gdev_M(g), gdevs.gdev_M(ϕ)
+    (; θsP, θsMs, entropy_ζ) = sample_posterior(rng, g_dev, ϕ_dev, xM;
         int_μP_ϕg_unc, int_unc, transP, transM,
-        n_sample_pred, cdev, cor_ends, pbm_covar_indices, kwargs...)
+        n_sample_pred, cdev=infer_cdev(gdevs), cor_ends, pbm_covar_indices, kwargs...)
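+    # attach ComponentArray axes, so that samples can be indexed by parameter
+    # name, e.g. θsPc[:r0, :] for all samples of a global parameter r0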
+ θsPc = ComponentArrayInterpreter(prob.θP, (n_sample_pred,))(θsP) + θsMsc = ComponentArrayInterpreter((n_site,), prob.θM, (n_sample_pred,))(θsMs) + (; θsP=θsPc, θsMs=θsMsc, entropy_ζ) end -function predict_hvi(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP; +function sample_posterior(rng, g, ϕ::AbstractVector, xM::AbstractMatrix; int_μP_ϕg_unc::AbstractComponentArrayInterpreter, int_unc::AbstractComponentArrayInterpreter, transP, transM, - n_sample_pred=200, - cdev=cpu_device(), + n_sample_pred, + cdev, cor_ends, pbm_covar_indices, is_inferred::Val{is_infer} = Val(false), kwargs... ) where is_infer - ζsP, ζsMs, σ = generate_ζ(rng, g, CA.getdata(ϕ), CA.getdata(xM); + ζsP_gpu, ζsMs_gpu, σ = generate_ζ(rng, g, CA.getdata(ϕ), CA.getdata(xM); int_μP_ϕg_unc, int_unc, n_MC=n_sample_pred, cor_ends, pbm_covar_indices) - ζsP_cpu = cdev(ζsP) - ζsMs_cpu = cdev(ζsMs) + ζsP = cdev(ζsP_gpu) + ζsMs = cdev(ζsMs_gpu) logdetΣ = 2 * sum(log.(σ)) entropy_ζ = entropy_MvNormal(length(σ), logdetΣ) - res_pred = is_infer ? - apply_f_trans(ζsP_cpu, ζsMs_cpu, f, xP; transP, transM, kwargs...) : - Test.@inferred apply_f_trans(ζsP_cpu, ζsMs_cpu, f, xP; transP, transM, kwargs...) - (; res_pred..., entropy_ζ) + trans_mP = StackedArray(transP, size(ζsP, 2)) + trans_mMs = StackedArray(transM, size(ζsMs, 1) * size(ζsMs, 3)) + θsP, θsMs = is_infer ? + Test.@inferred(transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs)) : + transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs) + # res_pred = is_infer ? + # apply_f_trans(ζsP_cpu, ζsMs_cpu, f, xP; transP, transM, kwargs...) : + # Test.@inferred apply_f_trans(ζsP_cpu, ζsMs_cpu, f, xP; transP, transM, kwargs...) + (; θsP, θsMs, entropy_ζ) end -""" -Compute predictions of the transformation at given -transformed parameters for each site. -The number of sites is given by the number of rows in `ζsMs`. +# """ +# Compute predictions of the transformation at given +# transformed parameters for each site. +# The number of sites is given by the number of rows in `ζsMs`. + +# Steps: +# - transform the parameters to original constrained space +# - Applies the mechanistic model for each site + +# `ζsP` and `ζsMs` are shaped according to the output of `generate_ζ`. +# Results are of shape `(n_obs x n_site_pred x n_MC)`. +# """ +# function apply_f_trans(ζsP::AbstractMatrix, ζsMs::AbstractArray, f, xP; +# transP, transM::Stacked, +# trans_mP=StackedArray(transP, size(ζsP, 2)), +# trans_mMs=StackedArray(transM, size(ζsMs, 1) * size(ζsMs, 3)) +# ) +# θsP, θsMs = transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs) +# y = apply_process_model(θsP, θsMs, f, xP) +# (; y, θsP, θsMs) +# end -Steps: -- transform the parameters to original constrained space -- Applies the mechanistic model for each site +# function apply_f_trans(ζP::AbstractVector, ζMs::AbstractMatrix, f, xP; +# transP, transM::Stacked, transMs::StackedArray=StackedArray(transM, size(ζMs, 1)), +# ) +# θP = transP(ζP) +# θMs = transMs(ζMs) +# y_global, y = f(θP, θMs, xP) +# (; y, θP, θMs) +# end -`ζsP` and `ζsMs` are shaped according to the output of `generate_ζ`. -Results are of shape `(n_obs x n_site_pred x n_MC)`. 
""" -function apply_f_trans(ζsP::AbstractMatrix, ζsMs::AbstractArray, f, xP; - transP, transM::Stacked, - trans_mP=StackedArray(transP, size(ζsP, 2)), - trans_mMs=StackedArray(transM, size(ζsMs, 1) * size(ζsMs, 3)) -) - θsP, θsMs = transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs) - y = apply_f(θsP, θsMs, f, xP) - (; y, θsP, θsMs) -end - -function apply_f_trans(ζP::AbstractVector, ζMs::AbstractMatrix, f, xP; - transP, transM::Stacked, transMs::StackedArray=StackedArray(transM, size(ζMs, 1)), -) - θP = transP(ζP) - θMs = transMs(ζMs) - y_global, y = f(θP, θMs, xP) - (; y, θP, θMs) -end + apply_process_model(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP) +Call a PBM applicator for a sample of parameters of each site, and stack results -function apply_f(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP) where ET +`θsP` and `θsMs` are shaped according to the output of `generate_ζ`, i.e. +`(n_site_pred x n_par x n_MC)`. +Results are of shape `(n_obs x n_site_pred x n_MC)`. +""" +function apply_process_model(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP) where ET y_pred = stack(map(eachcol(θsP), eachslice(θsMs, dims=3)) do θP, θMs y_global, y_pred_i = f(θP, θMs, xP) y_pred_i @@ -540,3 +591,8 @@ function flatten_hybrid_pars(xsP::AbstractMatrix{FT}, xsMs::AbstractArray{FT,3}) end + + + + + diff --git a/src/gf.jl b/src/gf.jl index 44ee7eb..11c552a 100644 --- a/src/gf.jl +++ b/src/gf.jl @@ -1,43 +1,43 @@ # Point solver where ML directly predicts PBL parameters, rather than their # distribution. -""" -Map process base model (PBM), `f`, across each site. +# """ +# Map process base model (PBM), `f`, across each site. -## Arguments -- `f(θ, xP, args...; intθ1, kwargs...)`: Process based model for single site +# ## Arguments +# - `f(θ, xP, args...; intθ1, kwargs...)`: Process based model for single site - Make sure to hint the type, so that results can be inferred. -- `θMst`: transposed model parameters across sites matrix: (n_parM, n_site_batch) -- `θP`: transposed model parameters that do not differ by site: (n_parP,) -- `θFix`: Further parameter required by f that are not calibrated. -- `xP`: Model drivers: Matrix with n_site_batch columns. - If provided a ComponentArray with labeled rows, f can then access `xP[:key]`. -- `intθ1`: ComponentArrayInterpreter that can be applied to θ, - so that entries can be extracted. +# Make sure to hint the type, so that results can be inferred. +# - `θMst`: transposed model parameters across sites matrix: (n_parM, n_site_batch) +# - `θP`: transposed model parameters that do not differ by site: (n_parP,) +# - `θFix`: Further parameter required by f that are not calibrated. +# - `xP`: Model drivers: Matrix with n_site_batch columns. +# If provided a ComponentArray with labeled rows, f can then access `xP[:key]`. +# - `intθ1`: ComponentArrayInterpreter that can be applied to θ, +# so that entries can be extracted. -See test_HybridProblem of using this function to construct a PBM function that -can predict across all sites. -""" -function map_f_each_site( - f, θMst::AbstractMatrix, θP::AbstractVector, θFix::AbstractVector, xP, args...; - intθ1::AbstractComponentArrayInterpreter, kwargs... 
-) - # predict several sites with same global parameters θP and fixed parameters θFix - it1 = eachcol(CA.getdata(θMst)) - it2 = eachcol(xP) - _θM = first(it1) - _x_site = first(it2) - TXS = typeof(_x_site) - TY = typeof(f(vcat(θP, _θM, θFix), _x_site, args...; intθ1, kwargs...)) - #TY = typeof(f(vcat(θP, _θM, θFix), _x_site; intθ1)) - yv = map(it1, it2) do θM, x_site - x_site_typed = x_site::TXS - f(vcat(θP, θM, θFix), x_site_typed, args...; intθ1, kwargs...) - end::Vector{TY} - y = stack(yv) - return(y) -end +# See test_HybridProblem of using this function to construct a PBM function that +# can predict across all sites. +# """ +# function map_f_each_site( +# f, θMst::AbstractMatrix, θP::AbstractVector, θFix::AbstractVector, xP, args...; +# intθ1::AbstractComponentArrayInterpreter, kwargs... +# ) +# # predict several sites with same global parameters θP and fixed parameters θFix +# it1 = eachcol(CA.getdata(θMst)) +# it2 = eachcol(xP) +# _θM = first(it1) +# _x_site = first(it2) +# TXS = typeof(_x_site) +# TY = typeof(f(vcat(θP, _θM, θFix), _x_site, args...; intθ1, kwargs...)) +# #TY = typeof(f(vcat(θP, _θM, θFix), _x_site; intθ1)) +# yv = map(it1, it2) do θM, x_site +# x_site_typed = x_site::TXS +# f(vcat(θP, θM, θFix), x_site_typed, args...; intθ1, kwargs...) +# end::Vector{TY} +# y = stack(yv) +# return(y) +# end # function map_f_each_site(f, θMs::AbstractMatrix, θPs::AbstractMatrix, θFix::AbstractVector, xP, args...; kwargs...) # # do not call f with matrix θ, because .* with vectors S1 would go wrong # yv = map(eachcol(θMs), eachcol(θPs), xP) do θM, θP, xP_site diff --git a/src/util_ca.jl b/src/util_ca.jl index 63561da..5336cd2 100644 --- a/src/util_ca.jl +++ b/src/util_ca.jl @@ -3,8 +3,12 @@ Move ComponentArray form gpu to cpu. """ -function cpu_ca end +#function cpu_ca end # define in FluxExt +function cpu_ca(ca::CA.ComponentArray) + CA.ComponentArray(cpu_device()(CA.getdata(ca)), CA.getaxes(ca)) +end + function apply_preserve_axes(f, ca::CA.ComponentArray) CA.ComponentArray(f(CA.getdata(ca)), CA.getaxes(ca)) diff --git a/test/Project.toml b/test/Project.toml index d1e29ea..81c469c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" diff --git a/test/runtests.jl b/test/runtests.jl index c41db77..7247a5c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,7 +42,7 @@ end @time begin if GROUP == "All" || GROUP == "Aqua" #@safetestset "test" include("test/test_aqua.jl") - if VERSION >= VersionNumber("1.11.2") + if VersionNumber("1.11.2") <= VERSION < VersionNumber("1.12") #@safetestset "test" include("test/test_aqua.jl") @time @safetestset "test_aqua" include("test_aqua.jl") end diff --git a/test/test_ComponentArrayInterpreter.jl b/test/test_ComponentArrayInterpreter.jl index 1286bf5..bd9a765 100644 --- a/test/test_ComponentArrayInterpreter.jl +++ b/test/test_ComponentArrayInterpreter.jl @@ -80,9 +80,9 @@ end; end; @testset "ComponentArrayInterpreter matrix and array" begin - mv = ComponentArrayInterpreter(; c1=2, c2=3) - #mv = ComponentArrayInterpreter(CA.ComponentVector(c1=1:2, c2=1:3)) - cv = mv(1:length(mv)) + mvi = ComponentArrayInterpreter(; 
c1=2, c2=3) + #mvi = ComponentArrayInterpreter(CA.ComponentVector(c1=1:2, c2=1:3)) + cv = mvi(1:length(mvi)) n_col = 4 mm = ComponentArrayInterpreter(cv, (n_col,)) # 1-tuple testm = (m) -> begin @@ -94,7 +94,7 @@ end; testm(mm) mmc = get_concrete(mm) testm(mmc) - mmi = ComponentArrayInterpreter(mv, (n_col,)) # construct on interpreter itself + mmi = ComponentArrayInterpreter(mvi, (n_col,)) # construct on interpreter itself testm(mmi) # n_z = 3 @@ -106,7 +106,7 @@ end; end testm(mm) testm(get_concrete(mm)) - mmi = ComponentArrayInterpreter(mv, (n_col, n_z)) # construct on interpreter itself + mmi = ComponentArrayInterpreter(mvi, (n_col, n_z)) # construct on interpreter itself testm(mmi) # n_row = 3 @@ -118,17 +118,17 @@ end; end testm(mm) testm(get_concrete(mm)) - mm = ComponentArrayInterpreter((n_row,), mv) # construct on interpreter itself + mm = ComponentArrayInterpreter((n_row,), mvi) # construct on interpreter itself testm(mmi) end; @testset "stack_ca_int" begin - mv = get_concrete(ComponentArrayInterpreter(CA.ComponentVector(c1=1:2, c2=1:3))) - #mv = ComponentArrayInterpreter(CA.ComponentVector(c1=1:2, c2=1:3)) - cv = mv(1:length(mv)) + mvi = get_concrete(ComponentArrayInterpreter(CA.ComponentVector(c1=1:2, c2=1:3))) + #mvi = ComponentArrayInterpreter(CA.ComponentVector(c1=1:2, c2=1:3)) + cv = mvi(1:length(mvi)) n_col = 4 n_dims = (n_col,) - mm = @inferred CP.stack_ca_int(mv, Val((n_col,))) # 1-tuple + mm = @inferred CP.stack_ca_int(mvi, Val((n_col,))) # 1-tuple @inferred get_positions(mm) # sizes are inferred here testm = (m) -> begin @test length(mm) == length(cv) * n_col @@ -139,7 +139,7 @@ end; testm(mm) # n_z = 3 - mm = @inferred stack_ca_int(mv, Val((n_col, n_z))) + mm = @inferred stack_ca_int(mvi, Val((n_col, n_z))) testm = (m) -> begin @test mm isa AbstractComponentArrayInterpreter @test length(mm) == length(cv) * n_col * n_z @@ -149,23 +149,23 @@ end; testm(mm) # n_row = 3 - mm = @inferred stack_ca_int(Val((n_row,)), mv) + mm = @inferred stack_ca_int(Val((n_row,)), mvi) testm = (m) -> begin @test mm isa AbstractComponentArrayInterpreter - @test length(mm) == n_row * length(mv) + @test length(mm) == n_row * length(mvi) cm = mm(1:length(mm)) @test cm[2, :c1] == [2, 5] end testm(mm) # f_n_within = (n) -> begin - mm = @inferred stack_ca_int(Val((n,)), mv) + mm = @inferred stack_ca_int(Val((n,)), mvi) end @test_broken @inferred f_n_within(3) # inferred is only f_outer = () -> begin f_n_within_cols = (n) -> begin - mm = @inferred stack_ca_int(mv, Val((n,))) - mm = get_concrete(ComponentArrayInterpreter(mv, (3,))) # same effects + mm = @inferred stack_ca_int(mvi, Val((n,))) + mm = get_concrete(ComponentArrayInterpreter(mvi, (3,))) # same effects end # @inferred f_n_within_cols(3) # inferred is only Any res = f_n_within_cols(3) # inferred is only diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index 1ab7bc3..962b5e4 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -17,6 +17,8 @@ using OptimizationOptimisers using MLDataDevices using Suppressor +using Functors + cdev = cpu_device() #scenario = Val((:default,)) @@ -32,27 +34,14 @@ function construct_problem(; scenario::Val{scen}) where scen cor_ends = (P=1:length(θP), M=[length(θM)]) # assume r0 independent of K2 int_θdoubleMM = get_concrete(ComponentArrayInterpreter( flatten1(CA.ComponentVector(; θP, θM)))) - function f_doubleMM(θ::AbstractVector{ET}, x; intθ1) where ET + function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET # extract parameters not depending on order, i.e whether 
they are in θP or θM - local θc = intθ1(θ) (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par CA.getdata(θc[par])::ET end local y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) return (y) end - f_doubleMM_sites = let intθ1 = int_θdoubleMM, f_doubleMM=f_doubleMM, - θFix = CA.ComponentVector{FT}() - function f_doubleMM_with_global_inner( - θP::AbstractVector{ET}, θMs::AbstractMatrix, xP - ) where ET - #first(eachcol(xP)) - local θMst = θMs' # map_f_each:site requires sites-last format - local pred_sites = map_f_each_site(f_doubleMM, θMst, θP, θFix, xP; intθ1) - local pred_global = eltype(pred_sites)[] - return pred_global, pred_sites - end - end n_out = length(θM) rng = StableRNG(111) # n_batch = 10 @@ -91,12 +80,17 @@ function construct_problem(; scenario::Val{scen}) where scen app, ϕg0 = construct_ChainsApplicator(rng, g_chain) g_chain_scaled = NormalScalingModelApplicator(app, lowers, uppers, FT) #g_chain_scaled = app - ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT)) + #ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT)) pbm_covars = (:covarK2 ∈ scen) ? (:K2,) : () - HybridProblem(θP, θM, g_chain_scaled, ϕg0, ϕunc0, - f_doubleMM_sites, f_doubleMM_sites, priors_dict, py, + f_batch = f_sites = PBMSiteApplicator( + f_doubleMM; θP, θM, θFix=CA.ComponentVector{FT}(), + xPvec=xP[:,1]) + HybridProblem(θP, θM, g_chain_scaled, ϕg0, + f_batch, f_sites, priors_dict, py, transM, transP, train_dataloader, n_covar, n_site, n_batch, - cor_ends, pbm_covars) + cor_ends, pbm_covars, + #ϕunc0, + ) end @testset "f_doubleMM from ProbSpec" begin @@ -120,6 +114,7 @@ end # @descend_code_warntype test_f_doubleMM(CA.getdata(θ2), xP1) end +#scenario = Val((:default,)) test_without_flux = (scenario) -> begin #scen = CP._val_value(scenario) gdev = @suppress gpu_device() @@ -160,6 +155,8 @@ test_without_flux = (scenario) -> begin loss_gf = get_loss_gf(g, transM, transP, f, y_global_o, intϕ; pbm_covars, n_site_batch = n_batch) (_xM, _xP, _y_o, _y_unc, _i_sites) = first(train_loader) + l1 = loss_gf(p0, _xM, _xP, _y_o, _y_unc, _i_sites) + l1 = @inferred ( # @descend_code_warntype ( loss_gf(p0, _xM, _xP, _y_o, _y_unc, _i_sites)) @@ -173,11 +170,13 @@ test_without_flux = (scenario) -> begin optprob = OptimizationProblem(optf, p0, train_loader) res = Optimization.solve( - # optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000); - optprob, Adam(0.02), maxiters=1000) - - l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...) - @test isapprox(par_templates.θP, intϕ(res.u).θP, rtol=0.11) + # optprob, Adam(0.02), + callback = callback_loss(100), + optprob, Adam(0.02), epochs = 150); + loss_gf_sites = get_loss_gf(g, transM, transP, f, y_global_o, intϕ; + pbm_covars, n_site_batch = n_site) + l1, y_pred_global, y_pred, θMs_pred = loss_gf_sites(res.u, train_loader.data...) 
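+        # compare the fitted global parameters, transformed back from the
+        # unconstrained scale, to the true templates; rtol is loose because
+        # of the short optimization in this test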
+ @test isapprox(par_templates.θP, transP(intϕ(res.u).ϕP), rtol=0.5) end end end @@ -202,8 +201,9 @@ test_with_flux = (scenario) -> begin #callback = callback_loss(100), maxiters = 1200 #maxiters = 1200 #maxiters = 20 - maxiters=200, - gdev = identity, + #maxiters=200, + epochs = 2, + gdevs = (; gdev_M=identity, gdev_P=identity), #gpu_handler = NullGPUDataHandler is_inferred = Val(true), ) @@ -222,9 +222,10 @@ test_with_flux = (scenario) -> begin (; ϕ, θP, resopt) = solve(prob, solver; scenario, rng, #callback = callback_loss(100), maxiters = 1200, #maxiters = 20 # too small so that it yields error - maxiters=37, + #maxiters=37, # still complains "need to specify maxiters or epochs" + epochs = 1, θmean_quant = 0.01, # test constraining mean to initial prediction - gdev = identity, + gdevs = (; gdev_M=identity, gdev_P=identity), is_inferred = Val(true), ) θPt = get_hybridproblem_par_templates(prob; scenario).θP @@ -233,9 +234,14 @@ test_with_flux = (scenario) -> begin θP prob.θP end; +end # test_with flux +test_with_flux(Val((:default,))) +test_with_flux(Val((:covarK2,))) - +#scenario = Val((:default,:useSitePBM)) +test_with_flux_gpu = (scenario) -> begin + # using Problem from DoubleMMCase if gdev isa MLDataDevices.AbstractGPUDevice @testset "HybridPosteriorSolver gpu $(last(CP._val_value(scenario)))" begin scenf = Val((CP._val_value(scenario)..., :use_Flux, :use_gpu, :omit_r0)) @@ -247,17 +253,21 @@ test_with_flux = (scenario) -> begin n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario = scenf) n_batches_in_epoch = n_site ÷ n_batch (; ϕ, θP, resopt) = solve(prob, solver; scenario = scenf, rng, - maxiters = 37, # smallest value by trial and error + #maxiters = 37, # smallest value by trial and error #maxiters = 20 # too small so that it yields error + epochs = 2, θmean_quant = 0.01, # test constraining mean to initial prediction is_inferred = Val(true), - ); + gdevs = (; gdev_M=gpu_device(), gdev_P=identity), + ); @test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector #@test cdev(ϕ.unc.ρsM)[1] > 0 # too few iterations in test -> may fail # solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) (; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf, - maxiters = 37, + #maxiters = 37, + epochs = 2, + gdevs = (; gdev_M=gpu_device(), gdev_P=identity), is_inferred = Val(true), ); @test cdev(ϕ.unc.ρsM)[1] > 0 @@ -273,6 +283,7 @@ test_with_flux = (scenario) -> begin n_epoch = 20 # requires (; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf, maxiters = n_batches_in_epoch * n_epoch, + gdevs = (; gdev_M=gpu_device(), gdev_P=identity), callback = callback_loss(n_batches_in_epoch*5) ); @test cdev(ϕ.unc.ρsM)[1] > 0 @@ -304,15 +315,21 @@ test_with_flux = (scenario) -> begin scenf = Val((CP._val_value(scenario)..., :use_Flux, :use_gpu, :omit_r0, :f_on_gpu)) rng = StableRNG(111) probg = HybridProblem(DoubleMM.DoubleMMCase(); scenario = scenf); + # put Applicator to gpu (θFix) + probg = HybridProblem( + probg, + f_batch = fmap(gdev, probg.f_batch), + f_allsites = fmap(gdev, probg.f_allsites)) #prob = CP.update(probg, transM = identity, transP = identity); solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) - n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario = scenf) + n_site, n_batch = get_hybridproblem_n_site_and_batch(probg; scenario = scenf) n_batches_in_epoch = n_site ÷ n_batch - (; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf, rng, - maxiters = 37, # smallest value by trial and error - #maxiters = 20 # too small so that 
it yields error + (; ϕ, θP, resopt, probo) = solve(probg, solver; scenario = scenf, rng, + #maxiters = 37, # smallest value by trial and error + #maxiters = 20, # too small so that it yields error + epochs = 1, #θmean_quant = 0.01, # TODO make possible on gpu - cdev = identity, # do not move ζ to cpu # TODO infer in solve from scenario + gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device()), is_inferred = Val(true), ); @test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector @@ -324,5 +341,7 @@ test_with_flux = (scenario) -> begin end # if gdev isa MLDataDevices.AbstractGPUDevice end # test_with flux -test_with_flux(Val((:default,))) -test_with_flux(Val((:covarK2,))) +test_with_flux_gpu(Val((:default,))) +test_with_flux_gpu(Val((:covarK2,))) +test_with_flux_gpu(Val((:default,:useSitePBM))) + diff --git a/test/test_cholesky_structure.jl b/test/test_cholesky_structure.jl index 882c0d6..2ec4b4f 100644 --- a/test/test_cholesky_structure.jl +++ b/test/test_cholesky_structure.jl @@ -118,20 +118,18 @@ end; @test_throws AssertionError CP.utri2vec_pos(2, 1) end -@testset "vec2uutri gpu" begin - if ggdev isa MLDataDevices.AbstractGPUDevice # only run the test, if CUDA is working (not on Github ci) - v_orig = 1.0f0:1.0f0:6.0f0 - #v = ggdev(v_orig) +function test_vec2uutri_gpu(v_orig, n) v = ggdev(collect(v_orig)) U1v = @inferred CP.vec2uutri(v) @test !(U1v isa UnitUpperTriangular) # on CUDA work with normal matrix @test U1v isa GPUArraysCore.AbstractGPUMatrix - @test size(U1v, 1) == 4 + @test size(U1v, 1) == n @test cdev(U1v) == CP.vec2uutri(v_orig) gr = Zygote.gradient(v -> sum(abs2.(CP.vec2uutri(v))), v)[1] # works nice @test gr isa GPUArraysCore.AbstractGPUArray - @test cdev(gr) == (1:6) .* 2.0 + @test cdev(gr) == v_orig .* 2.0 # + # conversion backwards v2 = @inferred CP.uutri2vec(U1v) @test v2 isa GPUArraysCore.AbstractGPUVector @test eltype(v2) == eltype(U1v) @@ -140,6 +138,16 @@ end @test gr isa GPUArraysCore.AbstractGPUArray @test all(diag(gr) .== 0) @test cdev(CP.uutri2vec(gr)) == fill(2.0f0, length(v_orig)) + # + tmp = CP.vec2uutri(v) # now applied to CuVector + #methods(CP.vec2uutri) +end +@testset "vec2uutri gpu" begin + if ggdev isa MLDataDevices.AbstractGPUDevice # only run the test, if CUDA is working (not on Github ci) + v_orig = [1.0f0]; n = 2 # 2x2 matrix with one parameter + test_vec2uutri_gpu(v_orig, n) + v_orig = 1.0f0:1.0f0:6.0f0; n = 4 #4x4 matrix with 1+2+3=6 parameters + test_vec2uutri_gpu(v_orig, n) end end; diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index 1d3407f..e87aab8 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -61,7 +61,8 @@ end is = repeat((1:length(θP_true))', n_site) θvec = CA.ComponentVector(P = θP_true, Ms = θMs_true) #xPM = map(xP1s -> repeat(xP1s', n_site), xP[1]) - xPM = (S1 = CA.getdata(xP[:S1, :])', S2 = CA.getdata(xP[:S2, :])') + #xPM = (S1 = CA.getdata(xP[:S1, :])', S2 = CA.getdata(xP[:S2, :])') + xPM = xP #θ = hcat(θP_true[is], θMs_true') intθ1 = get_concrete(ComponentArrayInterpreter(vcat(θP_true, θMs_true[:, 1]))) #θpos = get_positions(intθ1) @@ -70,25 +71,28 @@ end fy = let is = is, intθ = intθ (θvec, xPM) -> begin θ = hcat(CA.getdata(θvec.P[is]), CA.getdata(θvec.Ms')) - y = CP.DoubleMM.f_doubleMM(θ, xPM; intθ) + θc = intθ(θ) + y = CP.DoubleMM.f_doubleMM_sites(θc, xPM) #y = @inferred CP.DoubleMM.f_doubleMM(θ, xPM, intθ) # @descend_code_warntype CP.DoubleMM.f_doubleMM(θ, xPM, intθ) #y = CP.DoubleMM.f_doubleMM(θ, xPM, θpos) end end y = @inferred fy(θvec, xPM) - y_exp = map_f_each_site(CP.DoubleMM.f_doubleMM, θMs_true, 
θP_true, - Vector{eltype(θP_true)}(undef, 0), xP; intθ1) - @test y == y_exp' + + f_batch = PBMSiteApplicator(CP.DoubleMM.f_doubleMM; + θP = θP_true, θM = θMs_true[:,1], θFix=CA.ComponentVector(), xPvec=xP[:,1]) + y_exp = f_batch(θP_true, θMs_true', xP)[2] + @test y == y_exp ygrad = Zygote.gradient(θv -> sum(fy(θv, xPM)), θvec)[1] if gdev isa MLDataDevices.AbstractGPUDevice # θg = gdev(θ) # xPMg = gdev(xPM) # yg = CP.DoubleMM.f_doubleMM(θg, xPMg, intθ); θvecg = gdev(θvec); # errors without ";" - xPMg = gdev(xPM) + xPMg = CP.apply_preserve_axes(gdev, xPM) yg = @inferred fy(θvecg, xPMg) - @test cdev(yg) == y_exp' + @test cdev(yg) == y_exp ygradg = Zygote.gradient(θv -> sum(fy(θv, xPMg)), θvecg)[1] @test ygradg isa CA.ComponentArray @test CA.getdata(ygradg) isa GPUArraysCore.AbstractGPUArray @@ -101,7 +105,7 @@ end @testset "neg_logden_obs Matrix" begin is = repeat(axes(θP_true, 1)', n_site) θvec = CA.ComponentVector(P = θP_true, Ms = θMs_true) - xPM = (S1 = CA.getdata(xP[:S1, :])', S2 = CA.getdata(xP[:S2, :])') + xPM = xP #(S1 = CA.getdata(xP[:S1, :])', S2 = CA.getdata(xP[:S2, :])') #θ = hcat(θP_true[is], θMs_true') intθ1 = get_concrete(ComponentArrayInterpreter(vcat(θP_true, θMs_true[:, 1]))) #θpos = get_positions(intθ1) @@ -109,9 +113,10 @@ end fcost = let is = is, intθ = intθ, fneglogden=fneglogden (θvec, xPM, y_o, y_unc) -> begin θ = hcat(CA.getdata(θvec.P[is]), CA.getdata(θvec.Ms')) - y = CP.DoubleMM.f_doubleMM(θ, xPM; intθ) + θc = intθ(θ) + y = CP.DoubleMM.f_doubleMM_sites(θc, xPM) #y = CP.DoubleMM.f_doubleMM(θ, xPM, θpos) - res = fneglogden(y_o, y', y_unc) + res = fneglogden(y_o, y, y_unc) res end end @@ -259,7 +264,8 @@ if gdev isa MLDataDevices.AbstractGPUDevice @testset "transfer NormalScalingModelApplicator to gpu" begin @test g_gpu.μ isa GPUArraysCore.AbstractGPUArray y_gpu = g_gpu(xM_gpu, ϕg0_gpu) + @test y_gpu isa GPUArraysCore.AbstractGPUArray y = g(xM, ϕg0) @test cdev(y_gpu) ≈ y end -end \ No newline at end of file +end diff --git a/test/test_elbo.jl b/test/test_elbo.jl index c5b5a56..7569292 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -138,7 +138,8 @@ test_scenario = (scenario) -> begin UC = CP.transformU_cholesky1(ϕunc_true.ρsM); Σ = UC' * UC @test Σ[1,2] ≈ ρsM_true[1] - probd = CP.update(probc; ϕunc=ϕunc_true); + probd = HybridProblem(probc; ϕunc=ϕunc_true); + _ϕ = vcat(ϕ_ini.μP, probc.ϕg, probd.ϕunc) #hcat(ϕ_ini, ϕ, _ϕ)[1:4,:] #hcat(ϕ_ini, ϕ, _ϕ)[(end-20):end,:] @@ -205,7 +206,7 @@ test_scenario = (scenario) -> begin @testset "predict_hvi check sd" begin # test if uncertainty and reshaping is propagated # here inverse the predicted θs and then test distribution - probcu = CP.update(probc, ϕunc=ϕunc_true); + probcu = HybridProblem(probc, ϕunc=ϕunc_true); n_sample_pred = 24_000 (; y, θsP, θsMs, entropy_ζ) = predict_hvi(rng, probcu; scenario, n_sample_pred); #size(_ζsMs), size(θsMs) @@ -368,34 +369,19 @@ test_scenario = (scenario) -> begin end end - @testset "apply_f $(last(CP._val_value(scenario)))" begin - ζP = ζsP[:,1] - ζMs = ζsMs[:,:,1] - y_pred, θP_pred, θMs_pred = @inferred CP.apply_f_trans( - ζP, ζMs, f, xP[:,1:n_batch]; transP, transM) - @test size(y_pred) == size(y_o[:,1:n_batch]) - @test size(θP_pred) == (n_θP,) - @test size(θMs_pred) == (n_batch, n_θM) - # - ym_pred, θPm_pred, θMsm_pred = CP.apply_f_trans( - ζsP[:,1:1], ζMs[:,:,1:1], f, xP[:,1:n_batch]; transP, transM) - @test ym_pred[:,:,1] == y_pred - @test θPm_pred[:,1] == θP_pred - @test θMsm_pred[:,:,1] == θMs_pred - end - - @testset "predict_hvi cpu $(last(CP._val_value(scenario)))" begin + 
@testset "sample_posterior apply_process_model cpu $(last(CP._val_value(scenario)))" begin # intm_PMs_gen = get_ca_int_PMs(n_site) # trans_PMs_gen = get_transPMs(n_site) # @test length(intm_PMs_gen) == 402 # @test trans_PMs_gen.length_in == 402 n_sample_pred = 30 - (; y, θsP, θsMs, entropy_ζ) = + (; θsP, θsMs, entropy_ζ) = #Cthulhu.@descend_code_warntype ( @inferred ( - predict_hvi(rng, g, f_pred, ϕ_ini, xM, xP; + sample_posterior(rng, g, ϕ_ini, xM; int_μP_ϕg_unc, int_unc, transP, transM, + cdev = identity, n_sample_pred, cor_ends, pbm_covar_indices) ) @test θsP isa AbstractMatrix @@ -403,6 +389,8 @@ test_scenario = (scenario) -> begin int_mP = ComponentArrayInterpreter(int_P, (size(θsP, 2),)) θsPc = int_mP(θsP) @test all(θsPc[:r0, :] .> 0) + # + y = apply_process_model(θsP, θsMs, f_pred, xP) @test y isa Array @test size(y) == (size(y_o)..., n_sample_pred) end @@ -413,12 +401,13 @@ test_scenario = (scenario) -> begin ϕ_ini_g = ggdev(CA.getdata(ϕ_ini)) xMg = ggdev(xM) n_sample_pred = 30 - (; y, θsP, θsMs, entropy_ζ) = + (; θsP, θsMs, entropy_ζ) = #Cthulhu.@descend_code_warntype ( @inferred ( - predict_hvi(rng, g_gpu, f_pred, ϕ_ini_g, xMg, xP; + sample_posterior(rng, g_gpu, ϕ_ini_g, xMg; int_μP_ϕg_unc, int_unc, transP, transM, + cdev = cpu_device(), n_sample_pred, cor_ends, pbm_covar_indices) ) @test θsP isa AbstractMatrix @@ -426,6 +415,8 @@ test_scenario = (scenario) -> begin int_mP = ComponentArrayInterpreter(int_P, (size(θsP, 2),)) θsPc = int_mP(θsP) @test all(θsPc[:r0, :] .> 0) + # + y = apply_process_model(θsP, θsMs, f_pred, xP) @test y isa Array @test size(y) == (size(y_o)..., n_sample_pred) end