diff --git a/Project.toml b/Project.toml index b892ffc..392eb9d 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -60,6 +61,7 @@ StableRNGs = "1.0.2" StaticArrays = "1.9.13" StatsBase = "0.34.4" StatsFuns = "1.3.2" +Test = "1.10" julia = "1.10" [workspace] diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index d443ee8..6bcf7f1 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -15,16 +15,16 @@ import MLDataDevices, CUDA, cuDNN, GPUArraysCore rng = StableRNG(115) scenario = NTuple{0, Symbol}() -scenario = (:omit_r0,) # without omit_r0 ambiguous K2 estimated to high -scenario = (:use_Flux, :use_gpu) -scenario = (:use_Flux, :use_gpu, :omit_r0, :few_sites) -scenario = (:use_Flux, :use_gpu, :omit_r0, :few_sites, :covarK2) -scenario = (:use_Flux, :use_gpu, :omit_r0, :sites20, :covarK2) -scenario = (:use_Flux, :use_gpu, :omit_r0) -scenario = (:use_Flux, :use_gpu, :omit_r0, :covarK2) +scenario = Val((:omit_r0,)) # without omit_r0 ambiguous K2 estimated too high +scenario = Val((:use_Flux, :use_gpu)) +scenario = Val((:use_Flux, :use_gpu, :omit_r0, :few_sites)) +scenario = Val((:use_Flux, :use_gpu, :omit_r0, :few_sites, :covarK2)) +scenario = Val((:use_Flux, :use_gpu, :omit_r0, :sites20, :covarK2)) +scenario = Val((:use_Flux, :use_gpu, :omit_r0)) +scenario = Val((:use_Flux, :use_gpu, :omit_r0, :covarK2)) # prob = DoubleMM.DoubleMMCase() -gdev = :use_gpu ∈ scenario ? gpu_device() : identity +gdev = :use_gpu ∈ HVI._val_value(scenario) ? gpu_device() : identity cdev = gdev isa MLDataDevices.AbstractGPUDevice ? 
cpu_device() : identity #------ setup synthetic data and training data loader @@ -55,10 +55,12 @@ n_epoch = 80 maxiters = n_batches_in_epoch * n_epoch); # update the problem with optimized parameters prob0o = probo; -y_pred_global, y_pred, θMs = gf(prob0o, scenario); -plt = scatterplot(θMs_true[1, :], θMs[1, :]); +y_pred_global, y_pred, θMs = gf(prob0o; scenario, is_inferred=Val(true)); +# @descend_code_warntype gf(prob0o; scenario) +#@usingany UnicodePlots +plt = scatterplot(θMs_true'[:, 1], θMs[:, 1]); lineplot!(plt, 0, 1) -scatterplot(θMs_true[2, :], θMs[2, :]) +scatterplot(θMs_true'[:,2], θMs[:,2]) prob0o.θP #scatterplot(vec(y_true), vec(y_o)) #scatterplot(vec(y_true), vec(y_pred)) @@ -93,7 +95,7 @@ end # and fit gf starting from true parameters prob = prob0 g, ϕg0_cpu = get_hybridproblem_MLapplicator(prob; scenario) - ϕg0 = (:use_Flux ∈ scenario) ? gdev(ϕg0_cpu) : ϕg0_cpu + ϕg0 = (:use_Flux ∈ HVI._val_value(scenario)) ? gdev(ϕg0_cpu) : ϕg0_cpu (; transP, transM) = get_hybridproblem_transforms(prob; scenario) function loss_g(ϕg, x, g, transM; gpu_handler = HVI.default_GPU_DataHandler) @@ -160,9 +162,9 @@ n_epoch = 40 # update the problem with optimized parameters, including uncertainty prob1o = probo; n_sample_pred = 400 -#(; θ, y) = predict_gf(rng, prob1o, xM, xP; scenario, n_sample_pred); -(; θ, y) = predict_gf(rng, prob1o; scenario, n_sample_pred); -(θ1, y1) = (θ, y); +#(; θ, y) = predict_hvi(rng, prob1o, xM, xP; scenario, n_sample_pred); +(; y, θsP, θsMs) = predict_hvi(rng, prob1o; scenario, n_sample_pred, is_inferred=Val(true)); +(y1, θsP1, θsMs1) = (y, θsP, θsMs); () -> begin # prediction with fitted parameters (should be smaller than mean) y_pred_global, y_pred2, θMs = gf(prob1o, xM, xP; scenario) @@ -192,17 +194,18 @@ prob2o = probo; () -> begin using JLD2 #fname_probos = "intermediate/probos_$(last(scenario)).jld2" - fname_probos = "intermediate/probos800_$(last(scenario)).jld2" + fname_probos = 
"intermediate/probos800_$(last(HVI._val_value(scenario))).jld2" JLD2.save(fname_probos, Dict("prob1o" => prob1o, "prob2o" => prob2o)) tmp = JLD2.load(fname_probos) - # get_train_loader function could not be restored with JLD2 + # TODO replace function closure by Callable to store + # closure function could not be restored with JLD2 prob1o = HVI.update(tmp["prob1o"], get_train_loader = prob0.get_train_loader); prob2o = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader); end () -> begin # load the non-covar scenario using JLD2 - #fname_probos = "intermediate/probos_$(last(scenario)).jld2" + #fname_probos = "intermediate/probos_$(last(_val_value(scenario))).jld2" fname_probos = "intermediate/probos800_omit_r0.jld2" tmp = JLD2.load(fname_probos) # get_train_loader function could not be restored with JLD2 @@ -210,7 +213,7 @@ end prob2o_indep = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader); # test predicting correct obs-uncertainty of predictive posterior n_sample_pred = 400 - (; θ, y, entropy_ζ) = predict_gf(rng, prob2o_indep, xM, xP; scenario, n_sample_pred); + (; θ, y, entropy_ζ) = predict_hvi(rng, prob2o_indep, xM, xP; scenario, n_sample_pred); (θ2_indep, y2_indep) = (θ, y) #(θ2_indep, y2_indep) = (θ2, y2) # workaround to use covarK2 when loading failed end @@ -241,15 +244,16 @@ end #ζMs_VI = g_flux(xM_gpu, ζ_VIc.ϕg |> Flux.gpu) |> Flux.cpu ϕunc_VI = interpreters.unc(ζ_VIc.unc) ϕunc_VI.ρsM -exp.(ϕunc_VI.logσ2_logP) -exp.(ϕunc_VI.coef_logσ2_logMs[1, :]) +exp.(ϕunc_VI.logσ2_ζP) +exp.(ϕunc_VI.coef_logσ2_ζMs[1, :]) # test predicting correct obs-uncertainty of predictive posterior n_sample_pred = 400 -(; θ, y, entropy_ζ) = predict_gf(rng, prob2o; scenario, n_sample_pred); -(θ2, y2) = (θ, y) +(; y, θsP, θsMs) = predict_hvi(rng, prob2o; scenario, n_sample_pred); +(y2, θsP2, θsMs2) = (y, θsP, θsMs); + size(y) # n_obs x n_site, n_sample_pred -size(θ) # n_θP + n_site * n_θM x n_sample +size(θsMs) # n_site x n_θM x n_sample σ_o_post = 
dropdims(std(y; dims = 3), dims = 3); σ_o = exp.(y_unc[:, 1] / 2) @@ -264,11 +268,11 @@ plt = scatterplot(vec(y_true), vec(mean_y_pred)); lineplot!(plt, 0, 2) mean(mean_y_pred - y_true) # still ok -mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ[:, 1])[1]) -plt = scatterplot(θMs_true[1, :], mean_θ.Ms[1, :]); +mean_θMs = CA.ComponentArray(mean(θsMs; dims = 3)[:,:,1], CA.getaxes(θMs_true')) +plt = scatterplot(θMs_true'[:,1], mean_θMs[:,1]); lineplot!(plt, 0, 1) -plt = scatterplot(θMs_true[2, :], mean_θ.Ms[2, :]) -histogram(θ[:P,:]) +plt = scatterplot(θMs_true'[:,2], mean_θMs[:,2]) +histogram(θsP) #scatter(fig[1,1], CA.getdata(θMs_true[1, :]), CA.getdata(mean_θ.Ms[1, :])); ablines!(fig[1,1], 0, 1) #@usingany AlgebraOfGraphices #fig = Figure() @@ -320,10 +324,10 @@ end () -> begin # look at distribution of parameters, predictions, and likelihood and elob at one site function predict_site(probo, i_site) - (; θ, y, entropy_ζ) = predict_gf(rng, probo, xM, xP; scenario, n_sample_pred) + (; y, θsP, θsMs, entropy_ζ) = predict_hvi(rng, probo; scenario, n_sample_pred) y_site = y[:, i_site, :] - θMs_i = map(i_rep -> θ[:Ms, i_rep][:, i_site], axes(θ, 2)) - r1s = map(x -> x[1], θMs_i) + θMs_i = CA.ComponentArray(θsMs[i_site,:,:], (CA.getaxes(θMs_true)[1], CA.FlatAxis())) + r1s = θMs_i[:r1,:] # K1s = map(x -> x[2], θMs_i) # invt = map(Bijectors.inverse, get_hybridproblem_transforms(probo; scenario)) # θPs = θ[:P,:] @@ -348,6 +352,7 @@ end #@usingany CairoMakie #@usingany AlgebraOfGraphics + #@usingany DataFrames const aog = AlgebraOfGraphics # especially uncertainty is put to r1 (compensated by larger K1) @@ -466,12 +471,15 @@ cor_ends = get_hybridproblem_cor_ends(prob; scenario) g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario) ϕunc0 = get_hybridproblem_ϕunc(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) +hpints = HybridProblemInterpreters(prob; scenario) (; ϕ, transPMs_batch, interpreters, get_transPMs, 
get_ca_int_PMs) = init_hybrid_params( - θP, θM, cor_ends, ϕg0, n_site; transP, transM, ϕunc0); + θP, θM, cor_ends, ϕg0, hpints; transP, transM, ϕunc0); -intm_PMs_gen = get_ca_int_PMs(n_site); +intm_PMs_gen = get_int_PMs_site(hpints); #intm_PMs_gen = get_ca_int_PMs(100); -trans_PMs_gen = get_transPMs(n_site); +#trans_PMs_gen = get_transPMs(n_site); +trans_Ms_gen = StackedArray(transM, n_site) + #trans_PMs_gen = get_transPMs(100); """ @@ -483,9 +491,9 @@ transposeMs = (ζ, intm_PMs, back=false) -> begin Ms = back ? reshape(ζc.Ms, reverse(size(ζc.Ms))) : ζc.Ms ζct = vcat(CA.getdata(ζc.P), vec(CA.getdata(Ms)')) end -θ_true = vcat(CA.getdata(θP_true), vec(CA.getdata(θMs_true))); -ζ_true = log.(θ_true); -θ0_true = transposeMs(θ_true, intm_PMs_gen); +# θ_true = vcat(CA.getdata(θP_true), vec(CA.getdata(θMs_true))); +# ζ_true = log.(θ_true); +θ0_true = vcat(CA.getdata(θP_true), vec(CA.getdata(θMs_true'))); # note the transpose ζ0_true = log.(θ0_true); #transposeMs(θ0_true, intm_PMs_gen, true) == θ_true @@ -504,16 +512,24 @@ chain = sample(model, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_thread () -> begin using JLD2 - fname = "intermediate/doubleMM_chain_zeta_$(last(scenario)).jld2" + fname = "intermediate/doubleMM_chain_zeta_$(last(HVI._val_value(scenario))).jld2" jldsave(fname, false, IOStream; chain) chain = load(fname, "chain"; iotype = IOStream); + n_sample_NUTS = size(Array(chain),1) end #ζi = first(eachrow(Array(chain))) f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true) -ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true), hcat, eachrow(Array(chain))); -(; θ, y) = HVI.predict_ζf(ζs, f_allsites, xP, trans_PMs_gen, intm_PMs_gen); -(ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y); +#ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true), hcat, eachrow(Array(chain))); +ζsP = Array(chain)[:,1:n_θP]' +ζsMst = reshape(Array(chain)[:,(n_θP+1) : end], n_sample_NUTS, n_site, n_θM) +ζsMs = permutedims(ζsMst, (2,3,1)) +# need to reshape according 
to generate_ζ +ζsMs[:,:,1] # first sample: n_site x n_par +ζsMs[:,1,:] # first parameter n_site x n_sample + +(; y, θsP, θsMs) = HVI.apply_f_trans(ζsP, ζsMs, f_allsites, xP; transP, transM); +(y_hmc, θsP_hmc, θsMs_hmc) = (; y, θsP, θsMs); () -> begin # check that the model predicts the same as HVI-code @@ -522,7 +538,7 @@ f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true) #chain0 = Chains(transposeMs(ζ_true, intm_PMs_gen)', _parnames); chain0 = Chains(ζ0_true', _parnames); # transpose here only for chain array y0inv = generated_quantities(model, chain0)[1, 1] - y0pred = HVI.predict_y(ζ_true, xP, f, trans_PMs_gen, intm_PMs_gen)[2] + y0pred = HVI.apply_f_trans(ζ_true, xP, f, trans_PMs_gen, intm_PMs_gen)[2] y0pred .- y_true y0inv .- y_true end @@ -548,15 +564,16 @@ y_inv11 = y[1, 1, :] histogram(y_inv11 .- y_o[1, 1]) histogram(y_inv11 .- y_true[1, 1]) -histogram(ζs[1, :]) -describe(pdf.(Ref(prior_ζ), ζs[1, :])) # only small differences -pdf(prior_ζ, log(θ_true[1])) +histogram(ζsP[1, :]) +describe(pdf.(Ref(prior_ζ), ζsP[1, :])) # only small differences +pdf(prior_ζ, log(θP_true[1])) -mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ[:, 1])[1]) -histogram(θ[:P, :] .- θP_true) # all overestimated ? -plt = scatterplot(θMs_true[1, :], mean_θ.Ms[1, :]); +mean_θP = CA.ComponentArray(mean(θsP; dims = 2)[:, 1], CA.getaxes(θP_true)) +mean_θMs = CA.ComponentArray(mean(θsMs; dims = 3)[:,:, 1], CA.getaxes(θMs_true')) +histogram(θsP .- θP_true) # all overestimated ? 
+plt = scatterplot(θMs_true'[:,1], mean_θMs[:,1]); lineplot!(plt, 0, 1) -plt = scatterplot(θMs_true[2, :], mean_θ.Ms[2, :]); +plt = scatterplot(θMs_true'[:,2], mean_θMs[:,2]); lineplot!(plt, 0, 1) #------------------ compare HVI vs HMC sample @@ -570,28 +587,35 @@ lineplot!(plt, 0, 1) 72 .* (x_inch, y_inch) ./ cfg.pt_per_unit # size_pt end - ζs_hvi = log.(θ2) - ζs_hvi_indep = log.(θ2_indep) - int_pms = interpreters.PMs - par_pos = int_pms(1:length(int_pms)) + ζsP_hvi = log.(θsP2) + ζsP_hvi_indep = log.(θsP2) # TODO rerun and reload replace θsP2 + ζsP_hmc = log.(θsP_hmc) + ζsMs_hvi = log.(θsMs2) + ζsMs_hvi_indep = log.(θsMs2) # TODO rerun and reload replace θsMs2 + ζsMs_hmc = log.(θsMs_hmc) + # int_pms = interpreters.PMs + # par_pos = int_pms(1:length(int_pms)) i_sites = 1:10 #i_sites = 6:10 #i_sites = 11:15 - scen = vcat(fill(:hvi,size(ζs_hvi,2)),fill(:hmc,size(ζs_hmc,2)),fill(:hvi_indep,size(ζs_hvi_indep,2))) - dfP = mapreduce(vcat, axes(θP,1)) do i_par - pos = par_pos.P[i_par] + scen = vcat(fill(:hvi,size(ζsP_hvi,2)),fill(:hmc,size(ζsP_hmc,2)),fill(:hvi_indep,size(ζsP_hvi_indep,2))) + dfP = mapreduce(vcat, axes(θP_true,1)) do i_par + #pos = par_pos.P[i_par] DataFrame( - value = vcat(ζs_hvi[pos,:], ζs_hmc[pos,:], ζs_hvi_indep[pos,:]), - variable = keys(θP)[i_par], + value = vcat(ζsP_hvi[i_par, :], ζsP_hmc[i_par,:], ζsP_hvi_indep[i_par,:]), + variable = keys(θP_true)[i_par], site = i_sites[1], Method = scen ) end dfMs = mapreduce(vcat, i_sites) do i_site mapreduce(vcat, axes(θM,1)) do i_par - pos = par_pos.Ms[i_par, i_site] + #pos = par_pos.Ms[i_par, i_site] DataFrame( - value = vcat(ζs_hvi[pos,:], ζs_hmc[pos,:], ζs_hvi_indep[pos,:]), + value = vcat( + ζsMs_hvi[i_site,i_par,:], + ζsMs_hmc[i_site,i_par,:], + ζsMs_hvi_indep[i_site,i_par,:]), variable = keys(θM)[i_par], site = i_site, Method = scen @@ -627,7 +651,7 @@ lineplot!(plt, 0, 1) fg = draw!(fig, plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), axis=(xlabelvisible=false,)); fig save("tmp.svg", fig) - 
save_with_config("intermediate/compare_hmc_hvi_sites_$(last(scenario))", fig; makie_config) + save_with_config("intermediate/compare_hmc_hvi_sites_$(last(HVI._val_value(scenario)))", fig; makie_config) plt = (data(subset(df, :Method => ByRow(∈((:hvi, :hvi_indep))))) * mapping(:value=> (x -> x ) => "", color=:Method) * AlgebraOfGraphics.density(datalimits=extrema) + data(df_true) * mapping(:value) * visual(VLines; color=:blue, linestyle=:dash)) * @@ -636,7 +660,7 @@ lineplot!(plt, 0, 1) fg = draw!(fig, plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), axis=(xlabelvisible=false,)); fig save("tmp.svg", fig) - save_with_config("intermediate/compare_hvi_indep_sites_$(last(scenario))", fig; makie_config) + save_with_config("intermediate/compare_hvi_indep_sites_$(last(HVI._val_value(scenario)))", fig; makie_config) # # compare density of predictions @@ -695,6 +719,7 @@ lineplot!(plt, 0, 1) axis=(xlabelvisible=false,yticklabelsvisible=false)); legend!(fig[1,3], f, ; tellwidth=false, halign=:right, valign=:top) # , margin=(-10, -10, 10, 10) fig + save("tmp.svg", fig) - save_with_config("intermediate/compare_hmc_hvi_sites_y_$(last(scenario))", fig; makie_config) + save_with_config("intermediate/compare_hmc_hvi_sites_y_$(last(HVI._val_value(scenario)))", fig; makie_config) # hvi predicts y better, hmc fails for quite a few obs: 3,5,6 @@ -718,177 +743,115 @@ end end -#---- do an DEMC inversion of the PBM model with parameters at constrained scale -# construct a Normal prior that ranges roughly across 1e-2 to 10 -# prior_θ = fit(LogNormal, @qp_ll(1e-2), @qp_uu(10)) -# prior_θn = (n) -> MvLogNormal(fill(prior_θ.μ, n), PDiagMat(fill(abs2(prior_θ.σ), n))) -prior_θ = Normal(0, 10) -prior_θn = (n) -> MvNormal(fill(prior_θ.μ, n), PDiagMat(fill(abs2(prior_θ.σ), n))) -prior_θn(3) -prob = HVI.update(prob0o); - -(; θM, θP) = get_hybridproblem_par_templates(prob; scenario) -n_θM, n_θP = length.((θM, θP)) -f = get_hybridproblem_PBmodel(prob; scenario) -@model function fsites_uc( - y, ::Type{T} = Float64; f, n_θP, n_θM, σ_o, n_obs = length(σ_o)) where {T} - n_obs, n_site = size(y) - prior_θP = 
prior_θn(n_θP) - prior_θM_sites = fill(prior_θn(n_site), n_θM) - θP ~ prior_θP #MvNormal(n_θP, 10.0) - # CAUTION: order of vectorizing matrix depends on order of ~ - # need to assign each variable in first site first, then second site, ... - # need to construct different MvNormal prior if std differs by variable - # or need to take care when extracting samples or specifying initial conditions - θMs = Matrix{T}(undef, n_θM, n_site) - # the first loop vectorizes θMs by columns but is much slower - # for i_site in 1:n_site - # ζMs[:, i_site] ~ prior_ζn(n_θM) #MvNormal(n_site, 10.0) - # end - # this loop is faster, but vectorizes θMs by rows in parameter vector - for i_par in 1:n_θM - θMs[i_par, :] ~ prior_θM_sites[i_par] +() -> begin # depr---- do an DEMC inversion of the PBM model with parameters at constrained scale + # construct a Normal prior that ranges roughly across 1e-2 to 10 + # prior_θ = fit(LogNormal, @qp_ll(1e-2), @qp_uu(10)) + # prior_θn = (n) -> MvLogNormal(fill(prior_θ.μ, n), PDiagMat(fill(abs2(prior_θ.σ), n))) + prior_θ = Normal(0, 10) + prior_θn = (n) -> MvNormal(fill(prior_θ.μ, n), PDiagMat(fill(abs2(prior_θ.σ), n))) + prior_θn(3) + prob = HVI.update(prob0o); + + (; θM, θP) = get_hybridproblem_par_templates(prob; scenario) + n_θM, n_θP = length.((θM, θP)) + f = get_hybridproblem_PBmodel(prob; scenario) + + @model function fsites_uc( + y, ::Type{T} = Float64; f, n_θP, n_θM, σ_o, n_obs = length(σ_o)) where {T} + n_obs, n_site = size(y) + prior_θP = prior_θn(n_θP) + prior_θM_sites = fill(prior_θn(n_site), n_θM) + θP ~ prior_θP #MvNormal(n_θP, 10.0) + # CAUTION: order of vectorizing matrix depends on order of ~ + # need to assign each variable in first site first, then second site, ... 
+ # need to construct different MvNormal prior if std differs by variable + # or need to take care when extracting samples or specifying initial conditions + θMs = Matrix{T}(undef, n_θM, n_site) + # the first loop vectorizes θMs by columns but is much slower + # for i_site in 1:n_site + # ζMs[:, i_site] ~ prior_ζn(n_θM) #MvNormal(n_site, 10.0) + # end + # this loop is faster, but vectorizes θMs by rows in parameter vector + for i_par in 1:n_θM + θMs[i_par, :] ~ prior_θM_sites[i_par] + end + # this fills in rows first, but is also slower- why? + #ζMs[:] ~ prior_ζn(n_θM * n_site) + # assume σ_o known, see f_MM + #σ_o ~ truncated(Normal(0, 1); lower=0) + y_pred = f(θP, θMs, xP)[2] # first is global return + #i_obs = 1 + for i_obs in 1:n_obs + #pdf(MvNormal(y_pred[i_obs,:], σ_o[i_obs]),y[i_obs,:]) + y[i_obs, :] ~ MvNormal(y_pred[i_obs, :], σ_o[i_obs]) # single value σ instead of variance + end + #Main.@infiltrate_main # step to second time + # θMs_MCc[:,:,1] # checking row- or column-order of θMs + # exp.(ζMs) + y_pred end - # this fills in rows first, but is also slower- why? 
- #ζMs[:] ~ prior_ζn(n_θM * n_site) - # assume σ_o known, see f_MM - #σ_o ~ truncated(Normal(0, 1); lower=0) - y_pred = f(θP, θMs, xP)[2] # first is global return - #i_obs = 1 - for i_obs in 1:n_obs - #pdf(MvNormal(y_pred[i_obs,:], σ_o[i_obs]),y[i_obs,:]) - y[i_obs, :] ~ MvNormal(y_pred[i_obs, :], σ_o[i_obs]) # single value σ instead of variance + model_uc = fsites_uc(y_o; f, n_θP, n_θM, σ_o) + + () -> begin # check that the model predicts the same as HVI-code + _parnames = Symbol.(vcat([ + "θP[1]"], ["θMs[1, :][$i]" for i in 1:n_site], ["θMs[2, :][$i]" for i in 1:n_site])) + #chain0 = Chains(transposeMs(ζ_true, intm_PMs_gen)', _parnames); + chain0 = Chains(θ0_true', _parnames); # transpose here only for chain array + y0inv = generated_quantities(model_uc, chain0)[1, 1] + y0pred = f(θP_true, θMs_true, xP)[2] + y0pred .- y_true + y0inv .- y_true end - #Main.@infiltrate_main # step to second time - # θMs_MCc[:,:,1] # checking row- or column-order of θMs - # exp.(ζMs) - y_pred -end -model_uc = fsites_uc(y_o; f, n_θP, n_θM, σ_o) -() -> begin # check that the model predicts the same as HVI-code - _parnames = Symbol.(vcat([ - "θP[1]"], ["θMs[1, :][$i]" for i in 1:n_site], ["θMs[2, :][$i]" for i in 1:n_site])) - #chain0 = Chains(transposeMs(ζ_true, intm_PMs_gen)', _parnames); - chain0 = Chains(θ0_true', _parnames); # transpose here only for chain array - y0inv = generated_quantities(model_uc, chain0)[1, 1] - y0pred = f(θP_true, θMs_true, xP)[2] - y0pred .- y_true - y0inv .- y_true -end - - -#θ0_true from above, tools from above - -# takes long -n_sample_NUTS = 800 -#n_sample_NUTS = 24 -n_threads = 8 -chain = sample(model_uc, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_threads), - n_threads, initial_params = fill(θ0_true .+ 0.001, n_threads)) - -() -> begin - using JLD2 - jldsave("intermediate/doubleMM_chain_theta.jld2", false, IOStream; chain) - chain = load("intermediate/doubleMM_chain_theta.jld2", "chain"; iotype = IOStream) - # plot chain as above -end - -#θi = 
first(eachrow(Array(chain))) -θs = mapreduce(θi -> transposeMs(θi, intm_PMs_gen, true), hcat, eachrow(Array(chain))); -(; θ, y) = HVI.predict_ζf(θs, f, xP, Stacked(elementwise(identity)), intm_PMs_gen); - - -mean_y_invθ = map(mean, eachslice(y; dims = (1, 2))); -#describe(mean_y_pred - y_o) -histogram(vec(mean_y_invθ) - vec(y_true)) # predictions centered around y_o (or y_true) -plt = scatterplot(vec(y_true), vec(mean_y_invθ)); -lineplot!(plt, 0, 2) -mean(mean_y_invθ - y_true) # still ok - -# first site, first prediction -y_inv11 = y[1, 1, :] -histogram(y_inv11 .- y_o[1, 1]) -histogram(y_inv11 .- y_true[1, 1]) - -histogram(θs[1, :]) -describe(pdf.(Ref(prior_θ), θs[1, :])) # only small differences -pdf(prior_θ, log(θ_true[1])) - -mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ[:, 1])[1]) -histogram(θ[:P, :] .- θP_true) # all overestimated ? -plt = scatterplot(θMs_true[1, :], mean_θ.Ms[1, :]); -lineplot!(plt, 0, 1) -plt = scatterplot(θMs_true[2, :], mean_θ.Ms[2, :]); -lineplot!(plt, 0, 1) - - - - - - - - - - -#---- depr? 
-size(chain) -θc = Array(chain)' -θinv = CA.ComponentArray(θc, (CA.getaxes(θ[:, 1])[1], CA.Axis(i = 1:size(θc, 2)))) -mean_θinv = CA.ComponentVector( - mean(CA.getdata(θinv); dims = 2)[:, 1], CA.getaxes(θ[:, 1])[1]) - -@assert chain[:, 1, :1] == CA.getdata(θinv[:P, :][:K2, :]) -θP_true -plot = histogram(CA.getdata(θinv[:P, :][:K2, :])) - -plt = scatterplot(θMs_true[1, :], mean_θinv.Ms[1, :]); -lineplot!(plt, 0, 1); -plt = scatterplot(θMs_true[2, :], mean_θinv.Ms[2, :]) + + #θ0_true from above, tools from above + + # takes long + n_sample_NUTS = 800 + #n_sample_NUTS = 24 + n_threads = 8 + chain = sample(model_uc, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_threads), + n_threads, initial_params = fill(θ0_true .+ 0.001, n_threads)) + + () -> begin + using JLD2 + jldsave("intermediate/doubleMM_chain_theta.jld2", false, IOStream; chain) + chain = load("intermediate/doubleMM_chain_theta.jld2", "chain"; iotype = IOStream) + # plot chain as above + end -y_true = f(θP_true, θMs_true, xP)[2] -yinv = map(i -> f(θinv[:P, i], θinv[:Ms, i], xP)[2], axes(θinv, 2)) |> stack -histogram(yinv[1, 1, :]) -y_true[1, 1] + #θi = first(eachrow(Array(chain))) + θs = mapreduce(θi -> transposeMs(θi, intm_PMs_gen, true), hcat, eachrow(Array(chain))); + (; θ, y) = HVI.apply_f_trans(θs, f, xP, Stacked(elementwise(identity)), intm_PMs_gen); -tmp = generated_quantities(model_uc, chain[1:10, :, :]) + mean_y_invθ = map(mean, eachslice(y; dims = (1, 2))); + #describe(mean_y_pred - y_o) + histogram(vec(mean_y_invθ) - vec(y_true)) # predictions centered around y_o (or y_true) + plt = scatterplot(vec(y_true), vec(mean_y_invθ)); + lineplot!(plt, 0, 2) + mean(mean_y_invθ - y_true) # still ok -# reshape θMs (site x par) -> (par x site) -_intm_PMs = ComponentArrayInterpreter( - CA.ComponentVector(P = θP_true, Ms = vec(CA.getdata(θMs_true))), (n_sample_NUTS,)) -extract_parameters_fsites = (chain) -> begin - Ac = _intm_PMs(transpose(Array(chain))) - #θM = Ac[:Ms,:][:,1] - θMs = 
mapslices(CA.getdata(Ac[:Ms, :]), dims = 1) do θM - # (site x par) -> (par x site) - vec(reshape(θM, n_site, :)') - end - vcat(Ac[:P, :], θMs) -end + # first site, first prediction + y_inv11 = y[1, 1, :] + histogram(y_inv11 .- y_o[1, 1]) + histogram(y_inv11 .- y_true[1, 1]) -#ζs_MC = transpose(max.(-10.0,Array(chain))) -#ζs_MC = transpose(Array(chain)) -ζs_MC = extract_parameters_fsites(chain) -θs_MC = exp.(ζs_MC) + histogram(θs[1, :]) + describe(pdf.(Ref(prior_θ), θs[1, :])) # only small differences + pdf(prior_θ, log(θ_true[1])) -y_pred = y_pred_gen = stack(generated_quantities(model, chain)[:, 1]) + mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ[:, 1])[1]) + histogram(θ[:P, :] .- θP_true) # all overestimated ? + plt = scatterplot(θMs_true[1, :], mean_θ.Ms[1, :]); + lineplot!(plt, 0, 1) + plt = scatterplot(θMs_true[2, :], mean_θ.Ms[2, :]); + lineplot!(plt, 0, 1) +end # depr DEMC constrained scale -#ax_θPMs = _get_ComponentArrayInterpreter_axes(int_θPMs) -#intm_PMs = ComponentArrayInterpreter(ax_θPMs, n_sample_NUTS) -intm_PMs = ComponentArrayInterpreter( - CA.ComponentVector(P = 1:n_θP, Ms = 1:(n_θM * n_site_batch)), (n_sample_NUTS,)) -intm_Ps = ComponentArrayInterpreter(θP_true, (n_sample_NUTS,)) -intm_Ms = ComponentArrayInterpreter(θM_true, (n_site_batch, n_sample_NUTS)) -θs_MCc = intm_PMs(θs_MC) -θMs_MCc = intm_Ms(θs_MCc[:Ms, :]) -θPs_MCc = intm_Ps(θs_MCc[:P, :]) -ζs_MCc = intm_PMs(ζs_MC) -ζMs_MCc = intm_Ms(ζs_MCc[:Ms, :]) -ζP_MCc = intm_Ps(ζs_MCc[:P, :]) -# inspect correlation between physical parameter K and ML-parameter r at first (or ith) site +# TODO inspect correlation between physical parameter K and ML-parameter r at first (or ith) site mean_ζP_MC = mapslices(mean, CA.getdata(ζP_MCc), dims = 2)[:, 1] var_ζP_MC = map(x -> var(x; corrected = false), eachrow(ζP_MCc)) @@ -911,7 +874,7 @@ y_pred = stack(map(eachcol(θs_MC)) do θ θc = int_θPMs(θ) #θP, θMs = @view(θ[1:n_θP]), reshape(@view(θ[n_θP+1:end, :]), n_θM, :) θP, θMs = 
θc.θP, θc.θMs - y_pred_i = applyf(f_doubleMM, θMs, θP) + y_pred_i = map_f_each_site(f_doubleMM, θMs, θP) end) #hcat(y_pred[:,1,1], y_pred_gen[:,1,1]) diff --git a/ext/HybridVariationalInferenceCUDAExt.jl b/ext/HybridVariationalInferenceCUDAExt.jl index 393aea9..aa1e2f2 100644 --- a/ext/HybridVariationalInferenceCUDAExt.jl +++ b/ext/HybridVariationalInferenceCUDAExt.jl @@ -77,10 +77,11 @@ function uutri2vec_gpu!(v::Union{CUDA.CuVector,CUDA.CuDeviceVector}, X::Abstract return nothing # important end -function HVI._create_random(rng, ::CUDA.CuVector{T}, dims...) where {T} +function HVI._create_randn(rng, v::CUDA.CuVector{T,M}, dims...) where {T,M} # ignores rng # https://discourse.julialang.org/t/help-using-cuda-zygote-and-random-numbers/123458/4?u=bgctw - ChainRulesCore.@ignore_derivatives CUDA.randn(dims...) + res = ChainRulesCore.@ignore_derivatives CUDA.randn(T, dims...) + res::CUDA.CuArray{T, length(dims),M} end diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl index cbb2cf8..cda7157 100644 --- a/ext/HybridVariationalInferenceFluxExt.jl +++ b/ext/HybridVariationalInferenceFluxExt.jl @@ -18,7 +18,8 @@ end function HVI.apply_model(app::FluxApplicator, x, ϕ) m = app.rebuild(ϕ) - m(x) + res = m(x) + res end # struct FluxGPUDataHandler <: AbstractGPUDataHandler end @@ -38,7 +39,7 @@ end function HVI.construct_3layer_MLApplicator( rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:Flux}; - scenario::NTuple = ()) + scenario::Val{scen}) where scen (;θM) = get_hybridproblem_par_templates(prob; scenario) n_out = length(θM) n_covar = get_hybridproblem_n_covar(prob; scenario) @@ -46,7 +47,7 @@ function HVI.construct_3layer_MLApplicator( n_input = n_covar + n_pbm_covars #(; n_covar, n_θM) = get_hybridproblem_sizes(prob; scenario) float_type = get_hybridproblem_float_type(prob; scenario) - is_using_dropout = :use_dropout ∈ scenario + is_using_dropout = :use_dropout ∈ scen is_using_dropout && error("dropout scenario not 
supported with Flux yet.") g_chain = Flux.Chain( # dense layer with bias that maps to 8 outputs and applies `tanh` activation diff --git a/ext/HybridVariationalInferenceSimpleChainsExt.jl b/ext/HybridVariationalInferenceSimpleChainsExt.jl index f4ffb9e..b53262d 100644 --- a/ext/HybridVariationalInferenceSimpleChainsExt.jl +++ b/ext/HybridVariationalInferenceSimpleChainsExt.jl @@ -19,14 +19,14 @@ HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ) function HVI.construct_3layer_MLApplicator( rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:SimpleChains}; - scenario::NTuple = ()) + scenario::Val{scen}) where scen n_covar = get_hybridproblem_n_covar(prob; scenario) n_pbm_covars = length(get_hybridproblem_pbmpar_covars(prob; scenario)) n_input = n_covar + n_pbm_covars FloatType = get_hybridproblem_float_type(prob; scenario) (;θM) = get_hybridproblem_par_templates(prob; scenario) n_out = length(θM) - is_using_dropout = :use_dropout ∈ scenario + is_using_dropout = :use_dropout ∈ scen g_chain = if is_using_dropout SimpleChain( static(n_input), # input dimension (optional) diff --git a/src/AbstractHybridProblem.jl b/src/AbstractHybridProblem.jl index 645cc72..4705c4a 100644 --- a/src/AbstractHybridProblem.jl +++ b/src/AbstractHybridProblem.jl @@ -40,7 +40,8 @@ returns a Tuple of """ function get_hybridproblem_MLapplicator end -function get_hybridproblem_MLapplicator(prob::AbstractHybridProblem; scenario = ()) +function get_hybridproblem_MLapplicator( + prob::AbstractHybridProblem; scenario::Val{scen} = Val(())) where scen get_hybridproblem_MLapplicator(Random.default_rng(), prob; scenario) end @@ -202,13 +203,13 @@ end Put relevant parts of the DataLoader to gpu, depending on scenario. """ function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader; - scenario = (), + scenario::Val{scen} = Val(()), gdev = gpu_device(), - gdev_M = :use_gpu ∈ scenario ? gdev : identity, - gdev_P = :f_on_gpu ∈ scenario ? 
gdev : identity, + gdev_M = :use_gpu ∈ _val_value(scenario) ? gdev : identity, + gdev_P = :f_on_gpu ∈ _val_value(scenario) ? gdev : identity, batchsize = dataloader.batchsize, partial = dataloader.partial - ) + ) where scen xM, xP, y_o, y_unc, i_sites = dataloader.data xM_dev = gdev_M(xM) xP_dev, y_o_dev, y_unc_dev = (gdev_P(xP), gdev_P(y_o), gdev_P(y_unc)) diff --git a/src/ComponentArrayInterpreter.jl b/src/ComponentArrayInterpreter.jl index fe5bb2c..7f15d6c 100644 --- a/src/ComponentArrayInterpreter.jl +++ b/src/ComponentArrayInterpreter.jl @@ -21,12 +21,13 @@ Returns a ComponentArray with underlying data `v`. """ function as_ca end -function Base.length(cai::AbstractComponentArrayInterpreter) +function Base.length(cai::AbstractComponentArrayInterpreter) prod(_axis_length.(CA.getaxes(cai))) end - -(interpreter::AbstractComponentArrayInterpreter)(v::AbstractArray) = as_ca(v, interpreter) +function (interpreter::AbstractComponentArrayInterpreter)(v::AbstractArray{ET}) where ET + as_ca(v, interpreter)::CA.ComponentArray{ET} +end """ Concrete version of `AbstractComponentArrayInterpreter` that stores an axis @@ -39,11 +40,35 @@ Use `get_concrete(cai::ComponentArrayInterpreter)` to pass a concrete version to performance-critical functions. """ struct StaticComponentArrayInterpreter{AX} <: AbstractComponentArrayInterpreter end -function as_ca(v::AbstractArray, ::StaticComponentArrayInterpreter{AX}) where {AX} +function as_ca(v::AbstractArray, ::StaticComponentArrayInterpreter{AX}) where {AX} vr = reshape(v, _axis_length.(AX)) - CA.ComponentArray(vr, AX) + CA.ComponentArray(vr, AX)::CA.ComponentArray{eltype(v)} end +function StaticComponentArrayInterpreter(component_shapes::NamedTuple) + axs = map(component_shapes) do valx + x = _val_value(valx) + ax = x isa Integer ? 
CA.Shaped1DAxis((x,)) : CA.ShapedAxis(x) + (ax,) + end + axc = compose_axes(axs) + StaticComponentArrayInterpreter{(axc,)}() +end +function StaticComponentArrayInterpreter(ca::CA.ComponentArray) + ax = CA.getaxes(ca) + StaticComponentArrayInterpreter{ax}() +end + +# concatenate from several other ArrayInterpreters, keep static +# did not manage to get it inferred, better use get_concrete(ComponentArrayInterpreter) +# also does not save allocations +# function StaticComponentArrayInterpreter(; kwargs...) +# ints = values(kwargs) +# axc = compose_axes(ints) +# intc = StaticComponentArrayInterpreter{(axc,)}() +# return(intc) +# end + # function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX} # #sum(length, typeof(AX).parameters[1]) # prod(_axis_length.(AX)) @@ -55,7 +80,6 @@ end get_concrete(cai::StaticComponentArrayInterpreter) = cai - """ Non-Concrete version of `AbstractComponentArrayInterpreter` that avoids storing additional type parameters. @@ -66,23 +90,21 @@ not allow compiler-inferred `length` to construct StaticArrays. Use `get_concrete(cai::ComponentArrayInterpreter)` to pass a concrete version to performance-critical functions. """ -struct ComponentArrayInterpreter <: AbstractComponentArrayInterpreter +struct ComponentArrayInterpreter <: AbstractComponentArrayInterpreter axes::Tuple #{T, <:CA.AbstractAxis} end -function as_ca(v::AbstractArray, cai::ComponentArrayInterpreter) - vr = reshape(v, _axis_length.(cai.axes)) - CA.ComponentArray(vr, cai.axes) +function as_ca(v::AbstractArray, cai::ComponentArrayInterpreter) + vr = reshape(CA.getdata(v), _axis_length.(cai.axes)) + CA.ComponentArray(vr, cai.axes)::CA.ComponentArray{eltype(v)} end -function CA.getaxes(cai::ComponentArrayInterpreter) +function CA.getaxes(cai::ComponentArrayInterpreter) cai.axes end - get_concrete(cai::ComponentArrayInterpreter) = StaticComponentArrayInterpreter{cai.axes}() - """ ComponentArrayInterpreter(; kwargs...) 
ComponentArrayInterpreter(::AbstractComponentArray) @@ -108,71 +130,116 @@ The other constructors allow constructing arrays with additional dimensions. """ function ComponentArrayInterpreter(; kwargs...) ComponentArrayInterpreter(values(kwargs)) -end, +end function ComponentArrayInterpreter(component_shapes::NamedTuple) - component_counts = map(prod, component_shapes) - n = sum(component_counts) - x = 1:n - is_end = cumsum(component_counts) - is_start = (0, is_end[1:(end-1)]...) .+ 1 - #g = (x[i_start:i_end] for (i_start, i_end) in zip(is_start, is_end)) - g = (reshape(x[i_start:i_end], shape) for (i_start, i_end, shape) in zip(is_start, is_end, component_shapes)) - xc = CA.ComponentVector(; zip(propertynames(component_counts), g)...) - ComponentArrayInterpreter(xc) + #component_counts = map(prod, component_shapes) + # avoid constructing a template first, but create axes + # n = sum(component_counts) + # x = 1:n + # is_end = cumsum(component_counts) + # #is_start = (0, is_end[1:(end-1)]...) .+ 1 # problems with Zygote + # is_start = Iterators.flatten((1:1, is_end[1:(end-1)] .+ 1)) + # g = (reshape(x[i_start:i_end], shape) for (i_start, i_end, shape) in zip(is_start, is_end, component_shapes)) + # xc = CA.ComponentVector(; zip(propertynames(component_counts), g)...) + # #nt = NamedTuple{propertynames(component_counts)}(g) + # ComponentArrayInterpreter(xc) + axs = map(x -> (x isa Integer ? CA.Shaped1DAxis((x,)) : CA.ShapedAxis(x),), component_shapes) + ax = compose_axes(axs) + m1 = ComponentArrayInterpreter((ax,)) end function ComponentArrayInterpreter(vc::CA.AbstractComponentArray) ComponentArrayInterpreter(CA.getaxes(vc)) end - - # Attach axes to matrices and arrays of ComponentArrays # with ComponentArrays in the first dimensions (e.g. 
rownames of a matrix or array) function ComponentArrayInterpreter( - ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where N + ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where {N} ComponentArrayInterpreter(CA.getaxes(ca), n_dims) end function ComponentArrayInterpreter( - cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where N + cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where {N} ComponentArrayInterpreter(CA.getaxes(cai), n_dims) end function ComponentArrayInterpreter( - axes::NTuple{M, <:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N} + axes::NTuple{M,<:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N} axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...) ComponentArrayInterpreter(axes_ext) end +# support also for other AbstractComponentArrayInterpreter types +# in a type-stable way by providing the Tuple of dimensions as a value type +""" + stack_ca_int(cai::AbstractComponentArrayInterpreter, ::Val{n_dims}) + +Interpret the first dimension of an Array as a ComponentArray. Provide the Tuple +of following dimensions by a value type, e.g. `Val((n_col, n_z))`. +""" +function stack_ca_int( + cai::IT, ::Val{n_dims}) where {IT<:AbstractComponentArrayInterpreter,n_dims} + @assert n_dims isa NTuple{N,<:Integer} where {N} + IT.name.wrapper(CA.getaxes(cai), n_dims)::IT.name.wrapper +end +function StaticComponentArrayInterpreter( + axes::NTuple{M,<:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N} + axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...) + StaticComponentArrayInterpreter{axes_ext}() +end + # with ComponentArrays in the last dimensions (e.g. 
columnnames of a matrix) function ComponentArrayInterpreter( - n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where N + n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where {N} ComponentArrayInterpreter(n_dims, CA.getaxes(ca)) end function ComponentArrayInterpreter( - n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where N + n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where {N} ComponentArrayInterpreter(n_dims, CA.getaxes(cai)) end function ComponentArrayInterpreter( - n_dims::NTuple{N,<:Integer}, axes::NTuple{M, <:CA.AbstractAxis}) where {N,M} + n_dims::NTuple{N,<:Integer}, axes::NTuple{M,<:CA.AbstractAxis}) where {N,M} axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...) ComponentArrayInterpreter(axes_ext) end +function stack_ca_int( + ::Val{n_dims}, cai::IT) where {IT<:AbstractComponentArrayInterpreter,n_dims} + @assert n_dims isa NTuple{N,<:Integer} where {N} + IT.name.wrapper(n_dims, CA.getaxes(cai))::IT.name.wrapper +end +function StaticComponentArrayInterpreter( + n_dims::NTuple{N,<:Integer}, axes::NTuple{M,<:CA.AbstractAxis}) where {N,M} + axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...) + StaticComponentArrayInterpreter{axes_ext}() +end + # ambuiguity with two empty Tuples (edge prob that does not make sense) # Empty ComponentVector with no other array dimensions -> empty componentVector function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{}) - ComponentArrayInterpreter(CA.ComponentVector()) + ComponentArrayInterpreter((CA.Axis(),)) +end +function StaticComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{}) + StaticComponentArrayInterpreter{(CA.Axis(),)}() end +# concatenate several 1d ComponentArrayInterpreters +function compose_interpreters(; kwargs...) 
+ compose_interpreters(values(kwargs)) +end +function compose_interpreters(ints::NamedTuple) + axtuples = map(x -> CA.getaxes(x), ints) + axc = compose_axes(axtuples) + intc = ComponentArrayInterpreter((axc,)) + return (intc) +end # not exported, but required for testing _get_ComponentArrayInterpreter_axes(::StaticComponentArrayInterpreter{AX}) where {AX} = AX _get_ComponentArrayInterpreter_axes(cai::ComponentArrayInterpreter) = cai.axes - _axis_length(ax::CA.AbstractAxis) = lastindex(ax) - firstindex(ax) + 1 _axis_length(::CA.FlatAxis) = 0 _axis_length(::CA.UnitRange) = 0 @@ -199,7 +266,6 @@ function flatten1(cv::CA.ComponentVector) end end - """ get_positions(cai::AbstractComponentArrayInterpreter) @@ -207,7 +273,36 @@ Create a NamedTuple of integer indices for each component. Assumes that interpreter results in a one-dimensional array, i.e. in a ComponentVector. """ function get_positions(cai::AbstractComponentArrayInterpreter) - @assert length(CA.getaxes(cai)) == 1 + #@assert length(CA.getaxes(cai)) == 1 cv = cai(1:length(cai)) - (; (k => cv[k] for k in keys(cv))... ) + keys_cv = keys(cv) + # splatting creates Problems with Zygote + #keys_cv isa Tuple ? (; (k => CA.getdata(cv[k]) for k in keys_cv)...) : CA.getdata(cv) + keys_cv isa Tuple ? 
NamedTuple{keys_cv}(map(k -> CA.getdata(cv[k]), keys_cv)) : CA.getdata(cv) +end + +function tmpf(v; + cv, + cai::AbstractComponentArrayInterpreter=get_concrete(ComponentArrayInterpreter(cv))) + cai(v) +end + +function tmpf1(v; cai) + caic = get_concrete(cai) + #caic(v) + Test.@inferred tmpf(v, cv=nothing, cai=caic) +end + +function tmpf2(v; cai::AbstractComponentArrayInterpreter) + caic = get_concrete(cai) + #caic = cai + cv = Test.@inferred caic(v) # inferred inside tmpf2 + #cv = caic(v) # inferred inside tmpf2 + vv = tmpf(v; cv=nothing, cai=caic) + #vv = tmpf(v; cv) + #cv.x + #sum(cv) # not inferred on Union cv (axis not know) + #cv.x::AbstractVector{eltype(vv)} # not sufficient + # need to specify concrete return type, but can rely on eltype + sum(vv)::eltype(vv) # need to specify return type end diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl index 1f9e3d3..bbdb986 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -9,7 +9,7 @@ const θP_nor0 = θP[(:K2,)] const xP_S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.2, 0.1] const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0] -int_xP1 = ComponentArrayInterpreter(CA.ComponentVector(S1=xP_S1, S2=xP_S2)) +int_xP1 = ComponentArrayInterpreter(CA.ComponentVector(S1 = xP_S1, S2 = xP_S2)) # const transP = elementwise(exp) # const transM = elementwise(exp) @@ -18,16 +18,16 @@ int_xP1 = ComponentArrayInterpreter(CA.ComponentVector(S1=xP_S1, S2=xP_S2)) const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM))) -function f_doubleMM(θ::AbstractVector, x, intθ) +function f_doubleMM(θ::AbstractVector, x; intθ1) # extract parameters not depending on order, i.e whether they are in θP or θM - y = GPUArraysCore.allowscalar() do - θc = intθ(θ) + y = GPUArraysCore.allowscalar() do + θc = intθ1(θ) #using ComponentArrays: ComponentArrays as CA #r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)] # does not work on Zygote+GPU (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) 
do par # vector will be repeated when broadcasted by a matrix CA.getdata(θc[par]) - end + end # r0 = θc[:r0] # r1 = θc[:r1] # K1 = θc[:K1] @@ -37,25 +37,28 @@ function f_doubleMM(θ::AbstractVector, x, intθ) return (y) end -function f_doubleMM(θ::AbstractMatrix, x::NamedTuple, intθ::HVI.AbstractComponentArrayInterpreter) +function f_doubleMM( + θ::AbstractMatrix{T}, x; intθ::HVI.AbstractComponentArrayInterpreter) where T # provide θ for n_row sites # provide x.S1 as Matrix n_site x n_obs # extract parameters not depending on order, i.e whether they are in θP or θM - θc = intθ(θ) - @assert size(x.S1,1) == size(θ,1) # same number of sites - @assert size(x.S1) == size(x.S2) # same number of observations - #@assert length(x.s2 == n_obs) - # problems on AD on GPU with indexing CA may be related to printing result, use ";" - (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par - # vector will be repeated when broadcasted by a matrix - CA.getdata(θc[:,par]) - end - # r0 = CA.getdata(θc[:,:r0]) # vector will be repeated when broadcasted by a matrix - # r1 = CA.getdata(θc[:,:r1]) - # K1 = CA.getdata(θc[:,:K1]) - # K2 = CA.getdata(θc[:,:K2]) - y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) - return (y) + θc = intθ(θ) + @assert size(x.S1, 1) == size(θ, 1) # same number of sites + @assert size(x.S1) == size(x.S2) # same number of observations + #@assert length(x.s2 == n_obs) + # problems on AD on GPU with indexing CA may be related to printing result, use ";" + VT = typeof(θ[:,1]) # workaround for non-stable Symbol-indexing CAMatrix + #VT = first(Base.return_types(getindex, Tuple{typeof(θ),typeof(Colon()),typeof(1)})) + (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par + # vector will be repeated when broadcasted by a matrix + CA.getdata(θc[:, par]) ::VT + end + # r0 = CA.getdata(θc[:,:r0]) # vector will be repeated when broadcasted by a matrix + # r1 = CA.getdata(θc[:,:r1]) + # K1 = CA.getdata(θc[:,:K1]) + # K2 = CA.getdata(θc[:,:K2]) + y = r0 .+ r1 .* x.S1 ./ 
(K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) + return (y) end # function f_doubleMM(θ::AbstractMatrix, x::NamedTuple, θpos::NamedTuple) @@ -80,8 +83,9 @@ end # return (y) # end -function HVI.get_hybridproblem_par_templates(::DoubleMMCase; scenario::NTuple = ()) - if (:omit_r0 ∈ scenario) +function HVI.get_hybridproblem_par_templates( + ::DoubleMMCase; scenario::Val{scen}) where {scen} + if (:omit_r0 ∈ scen) #return ((; θP = θP_nor0, θM, θf = θP[(:K2r)])) return ((; θP = θP_nor0, θM)) end @@ -89,50 +93,69 @@ function HVI.get_hybridproblem_par_templates(::DoubleMMCase; scenario::NTuple = (; θP, θM) end -function HVI.get_hybridproblem_priors(::DoubleMMCase; scenario = ()) +# function HVI.get_hybridproblem_par_templates(::DoubleMMCase; scenario::NTuple = ()) +# if (:omit_r0 ∈ scenario) +# #return ((; θP = θP_nor0, θM, θf = θP[(:K2r)])) +# return ((; θP = θP_nor0, θM)) +# end +# #(; θP, θM, θf = eltype(θP)[]) +# (; θP, θM) +# end + +function HVI.get_hybridproblem_priors(::DoubleMMCase; scenario::Val{scen}) where {scen} Dict(keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95))) end -function HVI.get_hybridproblem_MLapplicator(prob::HVI.DoubleMM.DoubleMMCase; scenario = ()) +function HVI.get_hybridproblem_MLapplicator( + prob::HVI.DoubleMM.DoubleMMCase; scenario::Val{scen}) where {scen} rng = StableRNGs.StableRNG(111) get_hybridproblem_MLapplicator(rng, prob; scenario) end function HVI.get_hybridproblem_MLapplicator( - rng::AbstractRNG, prob::HVI.DoubleMM.DoubleMMCase; scenario = ()) + rng::AbstractRNG, prob::HVI.DoubleMM.DoubleMMCase; scenario::Val{scen}, + use_all_sites = false +) where {scen} ml_engine = select_ml_engine(; scenario) g_nomag, ϕ_g0 = construct_3layer_MLApplicator(rng, prob, ml_engine; scenario) # construct normal distribution from quantiles at unconstrained scale priors_dict = get_hybridproblem_priors(prob; scenario) + (; θM) = get_hybridproblem_par_templates(prob; scenario) priors = [priors_dict[k] for k in keys(θM)] (; transM) = 
get_hybridproblem_transforms(prob; scenario) - g = NormalScalingModelApplicator(g_nomag, priors, transM, eltype(ϕ_g0)) + lowers, uppers = HVI.get_quantile_transformed( + priors::AbstractVector{<:Distribution}, transM) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + n_site_batch = use_all_sites ? n_site : n_batch + g = NormalScalingModelApplicator( + g_nomag, lowers, uppers, eltype(ϕ_g0)) return g, ϕ_g0 end -function HVI.get_hybridproblem_pbmpar_covars(::DoubleMMCase; scenario) - if (:covarK2 ∈ scenario) +function HVI.get_hybridproblem_pbmpar_covars( + ::DoubleMMCase; scenario::Val{scen}) where {scen} + if (:covarK2 ∈ scen) return (:K2,) end () end - -function HVI.get_hybridproblem_transforms(prob::DoubleMMCase; scenario::NTuple = ()) +function HVI.get_hybridproblem_transforms( + prob::DoubleMMCase; scenario::Val{scen}) where {scen} _θP, _θM = get_hybridproblem_par_templates(prob; scenario) - if (:stackedMS ∈ scenario) - return (; transP = Stacked((HVI.Exp(),),(1:length(_θP),)), - transM = Stacked((identity,HVI.Exp(),),(1:1, 2:length(_θM),))) - elseif (:transIdent ∈ scenario) + if (:stackedMS ∈ scen) + return (; transP = Stacked((HVI.Exp(),), (1:length(_θP),)), + transM = Stacked((identity, HVI.Exp()), (1:1, 2:length(_θM)))) + elseif (:transIdent ∈ scen) # identity transformations, should AD on GPU - return (; transP = Stacked((identity,),(1:length(_θP),)), - transM = Stacked((identity,),(1:length(_θM),))) + return (; transP = Stacked((identity,), (1:length(_θP),)), + transM = Stacked((identity,), (1:length(_θM),))) end - (; transP = Stacked((HVI.Exp(),),(1:length(_θP),)), - transM = Stacked((HVI.Exp(),),(1:length(_θM),))) + (; transP = Stacked((HVI.Exp(),), (1:length(_θP),)), + transM = Stacked((HVI.Exp(),), (1:length(_θM),))) end -# function HVI.get_hybridproblem_sizes(::DoubleMMCase; scenario = ()) +# function HVI.get_hybridproblem_sizes(::DoubleMMCase; scenario::Val{scen}) where scen # n_covar_pc = 2 # n_covar = n_covar_pc + 3 # linear 
dependent # #n_site = 10^n_covar_pc @@ -151,17 +174,17 @@ end # intθ, θFix = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall) # let θFix = gdev(θFix), intθ = get_concrete(intθ) # function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP) -# pred_sites = applyf(f_doubleMM, θMs, θP, θFix, xP, intθ) +# pred_sites = map_f_each_site(f_doubleMM, θMs, θP, θFix, xP, intθ) # pred_global = eltype(pred_sites)[] # return pred_global, pred_sites # end # end # end -function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = (), - use_all_sites = false, - gdev = :f_on_gpu ∈ scenario ? gpu_device() : identity, - ) +function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::Val{scen}, + use_all_sites = false, + gdev = :f_on_gpu ∈ HVI._val_value(scenario) ? gpu_device() : identity +) where {scen} n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) n_site_batch = use_all_sites ? n_site : n_batch #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers @@ -170,14 +193,15 @@ function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = () θFix = repeat(θFix1', n_site_batch) intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1)) #int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1) - isP = repeat(axes(par_templates.θP,1)', n_site_batch) - let θFix = θFix, θFix_dev = gdev(θFix), intθ = get_concrete(intθ), isP=isP, - n_site_batch=n_site_batch, + isP = repeat(axes(par_templates.θP, 1)', n_site_batch) + let θFix = θFix, θFix_dev = gdev(θFix), intθ = get_concrete(intθ), isP = isP, + n_site_batch = n_site_batch, #int_xPb=get_concrete(int_xPb), pos_xP = get_positions(int_xP1) + function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP) - @assert size(xP,2) == n_site_batch - @assert size(θMs,2) == n_site_batch + @assert size(xP, 2) == n_site_batch + @assert size(θMs, 1) == n_site_batch # # convert vector of tuples to tuple of matricesByRows # # need to 
supply xP as vectorOfTuples to work with DataLoader # # k = first(keys(xP[1])) @@ -191,10 +215,11 @@ function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = () # make sure the same order of columns as in intθ # reshape big matrix into NamedTuple of drivers S1 and S2 # for broadcasting need sites in rows - xPM = map(p -> CA.getdata(xP[p,:])', pos_xP) + #xPM = map(p -> CA.getdata(xP[p,:])', pos_xP) + xPM = map(p -> CA.getdata(xP)'[:, p], pos_xP) θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? θFix_dev : θFix - θ = hcat(CA.getdata(θP[isP]), CA.getdata(θMs)', θFixd) - pred_sites = f_doubleMM(θ, xPM, intθ)' + θ = hcat(CA.getdata(θP[isP]), CA.getdata(θMs), θFixd) + pred_sites = f_doubleMM(θ, xPM; intθ)' pred_global = eltype(pred_sites)[] return pred_global, pred_sites end @@ -207,8 +232,7 @@ function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = () end end - -function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ()) +function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::Val) neg_logden_indep_normal end @@ -216,50 +240,52 @@ end # return Float32 # end - # two observations more? # const xP_S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.1] # const xP_S2 = Float32[1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0] -HVI.get_hybridproblem_n_covar(prob::DoubleMMCase; scenario) = 5 -function HVI.get_hybridproblem_n_site_and_batch(prob::DoubleMMCase; scenario) +HVI.get_hybridproblem_n_covar(prob::DoubleMMCase; scenario::Val) = 5 +function HVI.get_hybridproblem_n_site_and_batch(prob::DoubleMMCase; + scenario::Val{scen}) where {scen} n_batch = 20 n_site = 800 - if (:few_sites ∈ scenario) - n_site = 100 - elseif (:sites20 ∈ scenario) - n_site = 20 + if (:few_sites ∈ scen) + n_site = 100 + elseif (:sites20 ∈ scen) + n_site = 20 end (n_site, n_batch) end -function HVI.get_hybridproblem_train_dataloader(prob::DoubleMMCase; scenario = (), - rng::AbstractRNG = StableRNG(111), kwargs... 
- ) +function HVI.get_hybridproblem_train_dataloader(prob::DoubleMMCase; scenario::Val, + rng::AbstractRNG = StableRNG(111), kwargs... +) n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) construct_dataloader_from_synthetic(rng, prob; scenario, n_batch, kwargs...) end function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase; - scenario = ()) + scenario::Val{scen}) where {scen} n_covar_pc = 2 n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) n_covar = get_hybridproblem_n_covar(prob; scenario) n_θM = length(θM) FloatType = get_hybridproblem_float_type(prob; scenario) par_templates = get_hybridproblem_par_templates(prob; scenario) + #XXTODO transform θMs_true xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM; rhodec = 8, is_using_dropout = false) int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,)) # normalize to be distributed around the prescribed true values θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1))) - f = get_hybridproblem_PBmodel(prob; scenario, gdev=identity, use_all_sites = true) + f = get_hybridproblem_PBmodel(prob; scenario, gdev = identity, use_all_sites = true) #xP = fill((; S1 = xP_S1, S2 = xP_S2), n_site) - int_xPn = ComponentArrayInterpreter(int_xP1, (n_site,)) - xP = int_xPn(vcat(repeat(xP_S1,1,n_site),repeat(xP_S2,1,n_site))) + int_xP_sites = ComponentArrayInterpreter(int_xP1, (n_site,)) + xP = int_xP_sites(vcat(repeat(xP_S1, 1, n_site), repeat(xP_S2, 1, n_site))) #xP[:S1,:] θP = par_templates.θP - y_global_true, y_true = f(θP, θMs_true, xP) + #θint = ComponentArrayInterpreter( (size(θMs_true,2),), CA.getaxes(vcat(θP, θMs_true[:,1]))) + y_global_true, y_true = f(θP, θMs_true', xP) σ_o = FloatType(0.01) #σ_o = FloatType(0.002) logσ2_o = FloatType(2) .* log.(σ_o) @@ -278,3 +304,4 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase; y_unc = fill(logσ2_o, size(y_o)) ) end + diff --git 
a/src/HybridProblem.jl b/src/HybridProblem.jl index 5be6a91..a2c30eb 100644 --- a/src/HybridProblem.jl +++ b/src/HybridProblem.jl @@ -1,21 +1,21 @@ struct HybridProblem <: AbstractHybridProblem - θP::Any - θM::Any + θP::CA.ComponentVector + θM::CA.ComponentVector f_batch::Any f_allsites::Any - g::Any - ϕg::Any - ϕunc::Any - priors::Any - py::Any - transM::Any - transP::Any - cor_ends::Any # = (P=(1,),M=(1,)) - train_dataloader::Any + g::AbstractModelApplicator + ϕg::Any # depends on framework + ϕunc::CA.ComponentVector + priors::AbstractDict + py::Any # any callable + transM::Stacked + transP::Stacked + cor_ends::@NamedTuple{P::Vector{Int}, M::Vector{Int}} # = (P=(1,),M=(1,)) + train_dataloader::MLUtils.DataLoader n_covar::Int n_site::Int n_batch::Int - pbm_covars::NTuple + pbm_covars::NTuple{_N, Symbol} where _N #inner constructor to constrain the types function HybridProblem( θP::CA.ComponentVector, θM::CA.ComponentVector, @@ -24,9 +24,9 @@ struct HybridProblem <: AbstractHybridProblem f_batch::Function, f_allsites::Function, priors::AbstractDict, - py::Function, - transM::Union{Function, Bijectors.Transform}, - transP::Union{Function, Bijectors.Transform}, + py, + transM::Stacked, + transP::Stacked, # return a function that constructs the trainloader based on n_batch train_dataloader::MLUtils.DataLoader, n_covar::Int, @@ -93,33 +93,33 @@ function update(prob::HybridProblem; train_dataloader, n_covar, n_site, n_batch, cor_ends, pbm_covars) end -function get_hybridproblem_par_templates(prob::HybridProblem; scenario::NTuple = ()) +function get_hybridproblem_par_templates(prob::HybridProblem; scenario = ()) (; θP = prob.θP, θM = prob.θM) end -function get_hybridproblem_ϕunc(prob::HybridProblem; scenario::NTuple = ()) +function get_hybridproblem_ϕunc(prob::HybridProblem; scenario = ()) prob.ϕunc end -function get_hybridproblem_neg_logden_obs(prob::HybridProblem; scenario::NTuple = ()) +function get_hybridproblem_neg_logden_obs(prob::HybridProblem; scenario = ()) 
prob.py end -function get_hybridproblem_transforms(prob::HybridProblem; scenario::NTuple = ()) +function get_hybridproblem_transforms(prob::HybridProblem; scenario = ()) (; transP = prob.transP, transM = prob.transM) end -# function get_hybridproblem_sizes(prob::HybridProblem; scenario::NTuple = ()) +# function get_hybridproblem_sizes(prob::HybridProblem; scenario = ()) # n_θM = length(prob.θM) # n_θP = length(prob.θP) # (; n_covar=prob.n_covar, n_batch=prob.n_batch, n_θM, n_θP) # end -function get_hybridproblem_PBmodel(prob::HybridProblem; scenario::NTuple = (), use_all_sites=false) +function get_hybridproblem_PBmodel(prob::HybridProblem; scenario = (), use_all_sites=false) use_all_sites ? prob.f_allsites : prob.f_batch end -function get_hybridproblem_MLapplicator(prob::HybridProblem; scenario::NTuple = ()) +function get_hybridproblem_MLapplicator(prob::HybridProblem; scenario = ()) prob.g, prob.ϕg end @@ -144,6 +144,20 @@ function get_hybridproblem_priors(prob::HybridProblem; scenario = ()) prob.priors end -# function get_hybridproblem_float_type(prob::HybridProblem; scenario::NTuple = ()) +# function get_hybridproblem_float_type(prob::HybridProblem; scenario = ()) # eltype(prob.θM) # end + +""" +Get the inverse-transformation of lower and upper quantiles of a Vector of Distributions. + +This can be used to get proper confidence intervals at unconstrained (log) ζ-scale +for priors on normal θ-scale for constructing a NormalScalingModelApplicator. 
+""" +function get_quantile_transformed(priors::AbstractVector{<:Distribution}, trans; + q95 = (0.05, 0.95)) + θq = ([quantile(d, q) for d in priors] for q in q95) + lowers, uppers = inverse(trans).(θq) +end + + diff --git a/src/HybridSolver.jl b/src/HybridSolver.jl index 60eb1dc..f33e632 100644 --- a/src/HybridSolver.jl +++ b/src/HybridSolver.jl @@ -7,18 +7,20 @@ end HybridPointSolver(; alg) = HybridPointSolver(alg) function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolver; - scenario, rng = Random.default_rng(), - gdev = :use_gpu ∈ scenario ? gpu_device() : identity, - cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity, - kwargs...) + scenario, rng=Random.default_rng(), + gdev=:use_gpu ∈ _val_value(scenario) ? gpu_device() : identity, + cdev=gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity, + is_inferred::Val{is_infer} = Val(false), + kwargs... +) where is_infer par_templates = get_hybridproblem_par_templates(prob; scenario) g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario) FT = get_hybridproblem_float_type(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) intϕ = ComponentArrayInterpreter(CA.ComponentVector( - ϕg = 1:length(ϕg0), ϕP = par_templates.θP)) + ϕg=1:length(ϕg0), ϕP=par_templates.θP)) #ϕ0_cpu = vcat(ϕg0, par_templates.θP .* FT(0.9)) # slightly disturb θP_true - ϕ0_cpu = vcat(ϕg0, apply_preserve_axes(inverse(transP),par_templates.θP)) + ϕ0_cpu = vcat(ϕg0, apply_preserve_axes(inverse(transP), par_templates.θP)) train_loader = get_hybridproblem_train_dataloader(prob; scenario) if gdev isa MLDataDevices.AbstractGPUDevice ϕ0_dev = gdev(ϕ0_cpu) @@ -29,27 +31,33 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve g_dev = g train_loader_dev = train_loader end - f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=false) y_global_o = FT[] # TODO 
pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) #intP = ComponentArrayInterpreter(par_templates.θP) - loss_gf = get_loss_gf(g_dev, transM, transP, f, y_global_o, intϕ; cdev, pbm_covars) + loss_gf = get_loss_gf(g_dev, transM, transP, f, y_global_o, intϕ; + cdev, pbm_covars, n_site_batch=n_batch) # call loss function once - l1 = loss_gf(ϕ0_dev, first(train_loader_dev)...)[1] + l1 = is_infer ? + Test.@inferred(loss_gf(ϕ0_dev, first(train_loader_dev)...))[1] : + # using ShareAdd; @usingany Cthulhu + # @descend_code_warntype loss_gf(ϕ0_dev, first(train_loader_dev)...) + loss_gf(ϕ0_dev, first(train_loader_dev)...)[1] # and gradient # xMg, xP, y_o, y_unc = first(train_loader_dev) # gr1 = Zygote.gradient( # p -> loss_gf(p, xMg, xP, y_o, y_unc)[1], # ϕ0_dev) -165 # Zygote.gradient(ϕ0_dev -> loss_gf(ϕ0_dev, data1...)[1], ϕ0_dev) + # Zygote.gradient(ϕ0_dev -> loss_gf(ϕ0_dev, data1...)[1], ϕ0_dev) optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1], Optimization.AutoZygote()) optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader_dev) res = Optimization.solve(optprob, solver.alg; kwargs...) 
ϕ = intϕ(res.u) θP = cpu_ca(apply_preserve_axes(transP, cpu_ca(ϕ).ϕP)) - probo = update(prob; ϕg = cpu_ca(ϕ).ϕg, θP) - (; ϕ, resopt = res, probo) + probo = update(prob; ϕg=cpu_ca(ϕ).ϕg, θP) + (; ϕ, resopt=res, probo) end struct HybridPosteriorSolver{A} <: AbstractHybridSolver @@ -57,22 +65,24 @@ struct HybridPosteriorSolver{A} <: AbstractHybridSolver n_MC::Int n_MC_cap::Int end -function HybridPosteriorSolver(; alg, n_MC = 12, n_MC_cap = n_MC) +function HybridPosteriorSolver(; alg, n_MC=12, n_MC_cap=n_MC) HybridPosteriorSolver(alg, n_MC, n_MC_cap) end function update(solver::HybridPosteriorSolver; - alg = solver.alg, - n_MC = solver.n_MC, - n_MC_cap = n_MC) + alg=solver.alg, + n_MC=solver.n_MC, + n_MC_cap=n_MC) HybridPosteriorSolver(alg, n_MC, n_MC_cap) end function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorSolver; - scenario, rng = Random.default_rng(), - gdev = :use_gpu ∈ scenario ? gpu_device() : identity, - cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity, - θmean_quant = 0.0, - kwargs...) + scenario::Val{scen}, rng=Random.default_rng(), + gdev=:use_gpu ∈ _val_value(scenario) ? gpu_device() : identity, + cdev=gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity, + θmean_quant=0.0, + is_inferred::Val{is_infer} = Val(false), + kwargs... 
+) where {scen, is_infer} par_templates = get_hybridproblem_par_templates(prob; scenario) (; θP, θM) = par_templates cor_ends = get_hybridproblem_cor_ends(prob; scenario) @@ -81,8 +91,13 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS (; transP, transM) = get_hybridproblem_transforms(prob; scenario) pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + hpints = HybridProblemInterpreters(prob; scenario) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP, θM, cor_ends, ϕg0, n_batch; transP, transM, ϕunc0) + θP, θM, cor_ends, ϕg0, hpints; transP, transM, ϕunc0) + int_unc = interpreters.unc + int_μP_ϕg_unc = interpreters.μP_ϕg_unc + transMs = StackedArray(transM, n_batch) + # train_loader = get_hybridproblem_train_dataloader(prob; scenario) if gdev isa MLDataDevices.AbstractGPUDevice ϕ0_dev = gdev(ϕ) @@ -93,25 +108,31 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS g_dev = g train_loader_dev = train_loader end - f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=false) py = get_hybridproblem_neg_logden_obs(prob; scenario) - priors_θ_mean = construct_priors_θ_mean( + + priors_θP_mean, priors_θMs_mean = construct_priors_θ_mean( prob, ϕ0_dev.ϕg, keys(θM), θP, θmean_quant, g_dev, transM, transP; scenario, get_ca_int_PMs, gdev, cdev, pbm_covars) y_global_o = Float32[] # TODO + loss_elbo = get_loss_elbo( - g_dev, transPMs_batch, f, py, y_global_o, interpreters; - solver.n_MC, solver.n_MC_cap, cor_ends, priors_θ_mean, cdev, pbm_covars, θP) + g_dev, transP, transMs, f, py, y_global_o; + solver.n_MC, solver.n_MC_cap, cor_ends, priors_θP_mean, priors_θMs_mean, cdev, + pbm_covars, θP, int_unc, int_μP_ϕg_unc) # test loss function once - l0 = loss_elbo(ϕ0_dev, rng, first(train_loader_dev)...) 
- optf = Optimization.OptimizationFunction((ϕ, data) -> loss_elbo(ϕ, rng, data...)[1], + #Main.@infiltrate_main + l0 = is_infer ? + (Test.@inferred loss_elbo(ϕ0_dev, rng, first(train_loader_dev)...)) : + loss_elbo(ϕ0_dev, rng, first(train_loader_dev)...) + optf = Optimization.OptimizationFunction((ϕ, data) -> first(loss_elbo(ϕ, rng, data...)), Optimization.AutoZygote()) optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader_dev) res = Optimization.solve(optprob, solver.alg; kwargs...) ϕc = interpreters.μP_ϕg_unc(res.u) θP = cpu_ca(apply_preserve_axes(transP, ϕc.μP)) - probo = update(prob; ϕg = cpu_ca(ϕc).ϕg, θP = θP, ϕunc = cpu_ca(ϕc).unc); - (; ϕ = ϕc, θP, resopt = res, interpreters, probo) + probo = update(prob; ϕg=cpu_ca(ϕc).ϕg, θP=θP, ϕunc=cpu_ca(ϕc).unc) + (; ϕ=ϕc, θP, resopt=res, interpreters, probo) end function fit_narrow_normal(θi, prior, θmean_quant) @@ -138,18 +159,27 @@ The loss function takes in addition to ϕ, data that changes with minibatch - `xP`: drivers for the processmodel: Iterator of size n_site - `y_o`, `y_unc`: matrix of observations and uncertainties, sites in columns """ -function get_loss_elbo(g, transPMs, f, py, y_o_global, interpreters; - n_MC, n_MC_cap = n_MC, cor_ends, priors_θ_mean, cdev, pbm_covars, θP, - ) - let g = g, transPMs = transPMs, f = f, py = py, y_o_global = y_o_global, n_MC = n_MC, - cor_ends = cor_ends, interpreters = map(get_concrete, interpreters), - priors_θ_mean = priors_θ_mean, cdev = cdev, - pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars) +function get_loss_elbo(g, transP, transMs, f, py, y_o_global; + n_MC, n_MC_mean = max(n_MC,20), n_MC_cap=n_MC, + cor_ends, priors_θP_mean, priors_θMs_mean, cdev, pbm_covars, θP, + int_unc, int_μP_ϕg_unc, +) + let g = g, transP = transP, transMs = transMs, f = f, py = py, y_o_global = y_o_global, + n_MC = n_MC, n_MC_cap = n_MC_cap, n_MC_mean = n_MC_mean, + cor_ends = cor_ends, + int_unc = get_concrete(int_unc), int_μP_ϕg_unc = get_concrete(int_μP_ϕg_unc), + 
priors_θP_mean = priors_θP_mean, priors_θMs_mean = priors_θMs_mean, cdev = cdev, + pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars), + trans_mP=StackedArray(transP, n_MC_mean), + trans_mMs=StackedArray(transMs.stacked, n_MC_mean) function loss_elbo(ϕ, rng, xM, xP, y_o, y_unc, i_sites) neg_elbo_gtf( - rng, ϕ, g, transPMs, f, py, xM, xP, y_o, y_unc, i_sites, interpreters; - n_MC, n_MC_cap, cor_ends, priors_θ_mean, cdev, pbm_covar_indices) + rng, ϕ, g, f, py, xM, xP, y_o, y_unc, i_sites; + int_unc, int_μP_ϕg_unc, + n_MC, n_MC_cap, n_MC_mean, cor_ends, priors_θP_mean, priors_θMs_mean, + cdev, pbm_covar_indices, transP, transMs, trans_mP, trans_mMs, + ) end end end @@ -159,11 +189,11 @@ Compute the components of the elbo for given initial conditions of the problems for the first batch of the trainloader. """ function compute_elbo_components( - prob::AbstractHybridProblem, solver::HybridPosteriorSolver; - scenario, rng = Random.default_rng(), gdev = gpu_device(), - θmean_quant = 0.0, - use_all_sites = false, - kwargs...) + prob::AbstractHybridProblem, solver::HybridPosteriorSolver; + scenario, rng=Random.default_rng(), gdev=gpu_device(), + θmean_quant=0.0, + use_all_sites=false, + kwargs...) 
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) par_templates = get_hybridproblem_par_templates(prob; scenario) (; θP, θM) = par_templates @@ -172,7 +202,7 @@ function compute_elbo_components( ϕunc0 = get_hybridproblem_ϕunc(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP, θM, cor_ends, ϕg0, n_batch; transP, transM, ϕunc0) + θP, θM, cor_ends, ϕg0, n_batch; transP, transM, ϕunc0) train_loader = get_hybridproblem_train_dataloader(prob; scenario) if gdev isa MLDataDevices.AbstractGPUDevice ϕ0_dev = gdev(ϕ) @@ -200,21 +230,24 @@ In order to let mean of θ stay close to initial point parameter estimates construct a prior on mean θ to a Normal around initial prediction. """ function construct_priors_θ_mean(prob, ϕg, keysθM, θP, θmean_quant, g_dev, transM, transP; - scenario, get_ca_int_PMs, gdev, cdev, pbm_covars) - iszero(θmean_quant) ? [] : + scenario::Val{scen}, get_ca_int_PMs, gdev, cdev, pbm_covars) where {scen} + iszero(θmean_quant) ? ([],[]) : begin n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) # all_loader = MLUtils.DataLoader( # get_hybridproblem_train_dataloader(prob; scenario).data, batchsize = n_site) # xM_all = first(all_loader)[1] - is_gpu = :use_gpu ∈ scenario + is_gpu = :use_gpu ∈ scen xM_all_cpu = get_hybridproblem_train_dataloader(prob; scenario).data[1] xM_all = is_gpu ? 
gdev(xM_all_cpu) : xM_all_cpu ζP = apply_preserve_axes(inverse(transP), θP) pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars) - xMP_all = _append_each_covars(xM_all, CA.getdata(ζP), pbm_covar_indices) - #Main.@infiltrate_main - θMs = gtrans(g_dev, transM, xMP_all, CA.getdata(ϕg); cdev = cpu_device()) + xMP_all = _append_each_covars(xM_all, CA.getdata(ζP), pbm_covar_indices) + transMs = StackedArray(transM, n_site) + # ζMs = g_dev(xMP_all, CA.getdata(ϕg))' # transpose to par-last for StackedArray + # ζMs_cpu = cdev(ζMs) + # θMs = transMs(ζMs_cpu) + θMs = gtrans(g_dev, transMs, xMP_all, CA.getdata(ϕg); cdev=cpu_device()) priors_dict = get_hybridproblem_priors(prob; scenario) priorsP = [priors_dict[k] for k in keys(θP)] priors_θP_mean = map(priorsP, θP) do priorsP, θPi @@ -223,12 +256,13 @@ function construct_priors_θ_mean(prob, ϕg, keysθM, θP, θmean_quant, g_dev, priorsM = [priors_dict[k] for k in keysθM] i_par = 1 i_site = 1 - priors_θMs_mean = map(Iterators.product(axes(θMs)...)) do (i_par, i_site) + priors_θMs_mean = map(Iterators.product(axes(θMs)...)) do (i_site, i_par) #@show i_par, i_site - fit_narrow_normal(θMs[i_par, i_site], priorsM[i_par], θmean_quant) + fit_narrow_normal(θMs[i_site, i_par], priorsM[i_par], θmean_quant) end - # concatenate to a flat vector - int_n_site = get_ca_int_PMs(n_site) - int_n_site(vcat(priors_θP_mean, vec(priors_θMs_mean))) + # # concatenate to a flat vector + # int_n_site = get_ca_int_PMs(n_site) + # int_n_site(vcat(priors_θP_mean, vec(priors_θMs_mean))) + priors_θP_mean, priors_θMs_mean end end diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index 095e19d..97d2bf5 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -19,11 +19,15 @@ using Optimization using Distributions, DistributionFits using StaticArrays: StaticArrays as SA using Functors +using Test: Test # @inferred +export extend_stacked_nrow, StackedArray #export Exp 
-include("bijectors_utils.jl") +include("bijectors_utils.jl") -export ComponentArrayInterpreter, flatten1, get_concrete, get_positions +export AbstractComponentArrayInterpreter, ComponentArrayInterpreter, + StaticComponentArrayInterpreter +export flatten1, get_concrete, get_positions, stack_ca_int, compose_interpreters include("ComponentArrayInterpreter.jl") export AbstractModelApplicator, construct_ChainsApplicator @@ -52,10 +56,17 @@ export AbstractHybridProblem, get_hybridproblem_MLapplicator, get_hybridproblem_ setup_PBMpar_interpreter include("AbstractHybridProblem.jl") +export AbstractHybridProblemInterpreters, HybridProblemInterpreters, + get_int_P, get_int_M, + get_int_Ms_batch, get_int_Ms_site, get_int_Mst_batch, get_int_Mst_site, + get_int_PMs_batch, get_int_PMs_site, get_int_PMst_batch, get_int_PMst_site +include("hybridprobleminterpreters.jl") + export HybridProblem +export get_quantile_transformed include("HybridProblem.jl") -export applyf, gf, get_loss_gf +export map_f_each_site, gf, get_loss_gf include("gf.jl") export compute_correlated_covars, scale_centered_at @@ -73,7 +84,7 @@ include("logden_normal.jl") export get_ca_starts, get_ca_ends, get_cor_count include("cholesky.jl") -export neg_elbo_gtf, predict_gf +export neg_elbo_gtf, predict_hvi include("elbo.jl") export init_hybrid_params, init_hybrid_ϕunc diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl index 2bf1a01..5f83155 100644 --- a/src/ModelApplicator.jl +++ b/src/ModelApplicator.jl @@ -52,11 +52,13 @@ end """ construct_3layer_MLApplicator( rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ; - scenario::NTuple = ()) + scenario::Val{scen}) where scen Construct a machine learning model for given Proglem and machine learning engine. Implemented for machine learning extensions, such as Flux or SimpleChains. `ml_engine` usually is of type `Val{Symbol}`, e.g. Val(:Flux). See `select_ml_engine`. + +Scenario is a value-type of `NTuple{_,Symbol}`. 
""" function construct_3layer_MLApplicator end @@ -68,10 +70,10 @@ Returns a value type `Val{:Symbol}` to dispatch on the machine learning engine t - `:use_Lux ∈ scenario -> Val(:Lux)` - `:use_Flux ∈ scenario -> Val(:Flux)` """ -function select_ml_engine(;scenario) - if :use_Lux ∈ scenario +function select_ml_engine(; scenario::Val{scen}) where scen + if :use_Lux ∈ scen return Val(:Lux) - elseif :use_Flux ∈ scenario + elseif :use_Flux ∈ scen return Val(:Flux) else # default @@ -104,14 +106,15 @@ end NormalScalingModelApplicator(app, priors, transM) Wrapper around AbstractModelApplicator that transforms each output -(assumed in [0..1], usch as output of logistic activation function) +(assumed in [0..1], such as output of logistic activation function) to a quantile of a Normal distribution. Length of μ, σ must correspond to the number of outputs of the wrapped ModelApplicator. -This helps to keep raw ML-predictions (in confidence bounds) and weights in a similar magnitude. -Compared to specifying bounds, this allows for the possibility (although harder to converge) -far beyond the confidence bounds. +This helps to keep raw ML-predictions (in confidence bounds) and weights in a +similar magnitude. +Compared to specifying bounds, this allows for the possibility +(although harder to converge) far beyond the confidence bounds. The second constructor fits a normal distribution of the inverse-transformed 5% and 95% quantiles of prior distributions. @@ -123,19 +126,19 @@ struct NormalScalingModelApplicator{VF,A} <: AbstractModelApplicator end @functor NormalScalingModelApplicator +""" +Fit a Normal distribution to iterators lower and upper. +If `repeat_inner` is given, each fitted distribution is repeated as many times. 
+""" function NormalScalingModelApplicator( - app::AbstractModelApplicator, priors::AbstractVector{<:Distribution}, transM, ET::Type) - # need to apply transform to entire vectors each of lowers and uppers - θq = ([quantile(d, q) for d in priors] for q in (0.05, 0.95)) - ζlower, ζupper = inverse(transM).(θq) - #ipar = first(axes(ζlower,1)) - pars = map(axes(ζlower,1)) do ipar - dζ = fit(Normal, @qp_l(ζlower[ipar]), @qp_u(ζupper[ipar])) + app::AbstractModelApplicator, lowers::AbstractVector{<:Number}, uppers, ET::Type; repeat_inner::Integer = 1) + pars = map(lowers, uppers) do lower, upper + dζ = fit(Normal, @qp_l(lower), @qp_u(upper)) params(dζ) end # use collect to make it an array that works with gpu - μ = collect(ET, first.(pars)) - σ = collect(ET, last.(pars)) + μ = repeat(collect(ET, first.(pars)); inner=(repeat_inner,)) + σ = repeat(collect(ET, last.(pars)); inner=(repeat_inner,)) NormalScalingModelApplicator(app, μ, σ) end diff --git a/src/bijectors_utils.jl b/src/bijectors_utils.jl index 3cc70ee..efc0c96 100644 --- a/src/bijectors_utils.jl +++ b/src/bijectors_utils.jl @@ -1,3 +1,5 @@ +#------------------- Exp + struct Exp <: Bijector end @@ -16,6 +18,93 @@ end # x = transform(ib, y) # return x, -logabsdetjac(inverse(ib), x) # end +Bijectors.is_monotonically_increasing(::Exp) = true + + +""" + StackedArray(stacked, nrow) + +A Bijectors.Transform that applies stacked to each column of an n-row matrix. 
+""" +struct StackedArray{S} <: Bijectors.Transform + nrow::Int + stacked::S +end + +function StackedArray(stacked, nrow) + stacked_vec = extend_stacked_nrow(stacked, nrow) + StackedArray{typeof(stacked_vec)}(nrow, stacked_vec) +end + +Functors.@functor StackedArray (stacked,) + +function Base.show(io::IO, b::StackedArray) + return print(io, "StackedArray ($(b.nrow), $(b.stacked))") +end + +function Base.:(==)(b1::StackedArray, b2::StackedArray) + (b1.nrow == b2.nrow) && (b1.stacked == b2.stacked) +end + +Bijectors.isclosedform(b::StackedArray) = isclosedform(b.stacked) + +Bijectors.isinvertible(b::StackedArray) = isinvertible(b.stacked) + +_transform_stackedarray(sb, x) = reshape(sb.stacked(vec(x)), size(x)) +function _transform_stackedarray(sb, x::Adjoint{FT, <:GPUArraysCore.AbstractGPUArray}) where FT + # errors with Zygote for Adjoint of GPUArray, need to copy first + # TODO construct MWE and issue + x_plain = copy(x) + reshape(sb.stacked(vec(x_plain)), size(x_plain)) +end +function Bijectors.transform(sb::StackedArray, x::AbstractArray{<:Real}) + _transform_stackedarray(sb, x) +end + +_logabsdetjac_stackedarray(b,x) = logabsdet(b.stacked, vec(x)) +function Bijectors.logabsdetjac(b::StackedArray, x::AbstractArray{<:Real}) + _logabsdetjac_stackedarray(b,x) +end + +function Bijectors.with_logabsdet_jacobian(sb::StackedArray, x::AbstractArray) + (y, logjac) = with_logabsdet_jacobian(sb.stacked, vec(x)) + ym = reshape(y, size(x)) + return (ym, logjac) +end + +function Bijectors.inverse(sb::StackedArray) + inv_stacked = inverse(sb.stacked) + return StackedArray{typeof(inv_stacked)}(sb.nrow, inv_stacked) +end + +""" + extend_stacked_nrow(b::Stacked, nrow::Integer) + +Create a Stacked bijectors that transforms nrow times the elements +of the original Stacked bijector. 
+ +# Example +``` +X = reduce(hcat, ([x + y for x in 0:4 ] for y in 0:10:30)) +b1 = CP.Exp() +b2 = identity +b = Stacked((b1,b2), (1:1,2:4)) +bs = extend_stacked_nrow(b, size(X,1)) +Xt = reshape(bs(vec(X)), size(X)) +@test Xt[:,1] == b1(X[:,1]) +@test Xt[:,2:4] == b2(X[:,2:4]) +``` +""" +function extend_stacked_nrow(b::Stacked, nrow::Integer) + onet = one(eltype(first(b.ranges_in))) + endpos = last.(b.ranges_in) .* nrow + startpos2 = (endpos[1:(end-1)] .+ onet) + ranges = ntuple(length(endpos)) do i + startpos = i == 1 ? onet : startpos2[i-1] + startpos:endpos[i] + end + bs = Stacked(b.bs, ranges) +end + -Bijectors.is_monotonically_increasing(::Exp) = true diff --git a/src/cholesky.jl b/src/cholesky.jl index 3ea9589..52a94b6 100644 --- a/src/cholesky.jl +++ b/src/cholesky.jl @@ -15,7 +15,9 @@ function vec2utri(v::AbstractVector{T}; n=invsumn(length(v))) where {T} #https://groups.google.com/g/julia-users/c/UARlZBCNlng/m/6tKKxIeoEY8J z = zero(T) k = 0 - m = [j >= i ? (k += 1; v[k]) : z for i in 1:n, j in 1:n] + #m = T[j >= i ? (k += 1; v[k]) : z for i in 1:n, j in 1:n] # no Zygote + #m = [j >= i ? (k += 1; convert(T,v[k])) : convert(T,z) for i in 1:n, j in 1:n] # no Zygote + m = [j >= i ? (k += 1; v[k]) : z for i in 1:n, j in 1:n]::AbstractMatrix{T} # for typestability UpperTriangular(m) end @@ -65,7 +67,7 @@ function _vec2uutri( v::AbstractVector{T}; n=invsumn(length(v)) + one(T), diag=one(T)) where {T} z = zero(T) k = 0 - m = [j > i ? (k += 1; v[k]) : i == j ? diag : z for i in 1:n, j in 1:n] + m = [j > i ? (k += 1; v[k]) : i == j ? 
diag : z for i in 1:n, j in 1:n]::AbstractMatrix{T} return (m) end @@ -97,7 +99,7 @@ function utri2vec(X::AbstractMatrix{T}) where {T} X[i, j] end for _ in 1:lv - ] + ]::AbstractVector{T} end """ @@ -108,17 +110,18 @@ function uutri2vec(X::AbstractMatrix{T}) where {T} lv = sumn(n) i = 0 j = 2 - [ + if n == 0; return T[]; end # otherwise Any[] is returned + v = [ begin if i == j - 1 i = 0 j += 1 end i += 1 - X[i, j] + convert(T,X[i, j]) end for _ in 1:lv - ] + ]::AbstractVector{T} end function ChainRulesCore.rrule(::typeof(uutri2vec), X::AbstractMatrix{T}) where {T} @@ -194,7 +197,7 @@ Useful for providing information on correlactions among subranges in a vector. """ function get_ca_ends(vc::CA.ComponentVector) #(cumsum(length(vc[k]) for k in keys(vc))...,) - length(vc) == 0 ? Int[] : cumsum(length(vc[k]) for k in keys(vc)) + length(keys(vc)) == 0 ? Int[] : cumsum(length(vc[k])::Int for k in keys(vc)) end @@ -209,7 +212,7 @@ function get_cor_count(cor_ends::AbstractVector) sum(get_cor_counts(cor_ends)) end function get_cor_counts(cor_ends::AbstractVector{T}) where {T} - isempty(cor_ends) && return (zero(T)) + isempty(cor_ends) && return (zeros(T,1)) cnt_blocks = ( begin i == 1 ? cor_ends[i] : cor_ends[i] - cor_ends[i-1] @@ -233,9 +236,9 @@ the blocks start at columns (3,5,6). It defaults to a single entire block. """ function transformU_block_cholesky1( v::AbstractVector{T}, cor_ends::AbstractVector{TI}=Int[]) where {T,TI<:Integer} - #@show v, cor_ends if length(cor_ends) <= 1 # if there is only one block, return it - return transformU_cholesky1(v) + # for type stability create a BlockDiagonal of a single block + return _create_blockdiag(v, [transformU_cholesky1(v)]) end cor_counts = get_cor_counts(cor_ends) # number of correlation parameters #@show cor_counts diff --git a/src/elbo.jl b/src/elbo.jl index dfa69ea..8db37c0 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -32,55 +32,69 @@ function neg_elbo_gtf(args...; kwargs...) 
nLy - entropy_ζ + nLmean_θ end -function neg_elbo_gtf_components(rng, ϕ::AbstractVector, g, transPMs, f, py, - xM::AbstractMatrix, xP, y_ob, y_unc, i_sites::AbstractVector{<:Number}, - interpreters::NamedTuple; - n_MC = 12, n_MC_mean = 30, n_MC_cap = n_MC, - cdev = cpu_device(), - priors_θ_mean = [], - cor_ends, # =(P=(1,),M=(1,)) - pbm_covar_indices -) - n_MCr = isempty(priors_θ_mean) ? n_MC : max(n_MC, n_MC_mean) - ζs, σ = generate_ζ(rng, g, ϕ, xM, interpreters; n_MC = n_MCr, cor_ends, pbm_covar_indices) - ζs_cpu = cdev(ζs) # differentiable fetch to CPU in Flux package extension - nLy, entropy_ζ = neg_elbo_ζtf(ζs_cpu, σ, transPMs, f, py, - xP, y_ob, y_unc, interpreters; - n_MC, n_MC_cap - ) - nLmean_θ = isempty(priors_θ_mean) ? 0.0 : - begin - # compute the mean of predicted and transformed site-parameters - # avoid mapslices because of Zygote - # θs0 = mapslices(transPMs, ζs_cpu, dims=[1]) - # θPs0 = mapslices(θ -> interpreters.PMs(θ).P, θs0, dims = 1) - #θs = (transPMs(ζ) for ζ in eachcol(ζs_cpu)) # does not work with Zygote - θs = map(transPMs, eachcol(ζs_cpu)) - θPs = map(θ -> CA.getdata(interpreters.PMs(θ).P), θs) |> stack - # does not need to allocate vectors but does not work with Zygote: - # θPs = (CA.getdata(interpreters.PMs(θ).P) for θ in θs) |> stack - mean_θP = mean(CA.getdata(θPs); dims = (2))[:, 1] - #nLmean_θP = map((d, θi) -> -logpdf(d, θi), CA.getdata(priors_θ_mean.P), mean_θP) - #workaround for Zygote failing on `priors_θ_mean.P` - iθ = CA.ComponentArray(1:length(priors_θ_mean), CA.getaxes(priors_θ_mean)) - # need to apply different dist to each entry in θP and mean_θMs -> @allowscalar - # but does not work with Zygote - nLmean_θP = map((d, θi) -> -logpdf(d, θi), priors_θ_mean[CA.getdata(iθ.P)], mean_θP) - θMss = map(θ -> interpreters.PMs(θ).Ms, θs) |> stack - mean_θMs = mean(θMss; dims = (3))[:, :, 1] - nLmean_θMs = map((d, θi) -> -logpdf(d, θi), - CA.getdata(priors_θ_mean[CA.getdata(iθ.Ms)])[:, i_sites], mean_θMs) - nLmean_θ = sum(nLmean_θP) + 
sum(nLmean_θMs) - end +function neg_elbo_gtf_components(rng, ϕ::AbstractVector{FT}, g, f, py, + xM::AbstractMatrix, xP, y_ob, y_unc, i_sites::AbstractVector{<:Number}; + int_μP_ϕg_unc::AbstractComponentArrayInterpreter, + int_unc::AbstractComponentArrayInterpreter, + n_MC=12, n_MC_mean=n_MC, n_MC_cap=n_MC, + cdev=cpu_device(), + priors_θP_mean=[], + priors_θMs_mean=[], + #priors_θ_mean=[], + cor_ends, # =(P=(1,),M=(1,)) + pbm_covar_indices, + transP, transMs, + trans_mP =StackedArray(transP, n_MC), # provide with creating cost function + trans_mMs =StackedArray(transMs.stacked, n_MC), +) where {FT} + n_MCr = isempty(priors_θP_mean) ? n_MC : max(n_MC, n_MC_mean) + ζsP, ζsMs, σ = generate_ζ(rng, g, ϕ, xM; n_MC=n_MCr, cor_ends, pbm_covar_indices, + int_unc, int_μP_ϕg_unc) + ζsP_cpu = cdev(ζsP) # fetch to CPU, because for <1000 sites (n_batch) this is faster + ζsMs_cpu = cdev(ζsMs) # fetch to CPU, because for <1000 sites (n_batch) this is faster + # + # maybe: translate ζ once and supply to both neg_elbo and negloglik_meanθ + nLy, entropy_ζ = neg_elbo_ζtf( + ζsP_cpu[:,1:n_MC], ζsMs_cpu[:,:,1:n_MC], σ, f, py, xP, y_ob, y_unc; + n_MC_cap, transP, transMs, ) + # + # maybe: provide trans_mP and trans_mMs with creating cost function + nLmean_θ = _compute_negloglik_meanθ(ζsP_cpu, ζsMs_cpu; + trans_mP, trans_mMs, priors_θP_mean, priors_θMs_mean, i_sites) nLy, entropy_ζ, nLmean_θ end -function neg_elbo_ζtf(ζs, σ, transPMs, f, py, - xP, y_ob, y_unc, interpreters::NamedTuple; - n_MC = 12, n_MC_cap = n_MC +function _compute_negloglik_meanθ(ζsP::AbstractMatrix{FT}, ζsMs; + priors_θP_mean, priors_θMs_mean, i_sites, trans_mP, trans_mMs, +) where FT + if isempty(priors_θP_mean) + return zero(FT) + end + θsP, θsMs = transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs) + mean_θP = mean(CA.getdata(θsP); dims=(2))[:, 1] + nLmean_θP = map((d, θi) -> -logpdf(d, θi), priors_θP_mean, mean_θP) + mean_θMs = mean(θsMs; dims=(3))[:, :, 1] + nLmean_θMs = map((d, θi) -> -logpdf(d, θi), 
priors_θMs_mean[i_sites], mean_θMs) + nLmean_θ = sum(nLmean_θP) + sum(nLmean_θMs) + convert(FT,nLmean_θ)::FT +end + +""" +Compute the neg_elbo for each sampled parameter vector (last dimension of ζs). +- Transform and compute log-jac +- call forward model +- compute log-density of predictions +- compute entropy of transformation +""" +function neg_elbo_ζtf(ζsP, ζsMs, σ, f, py, xP, y_ob, y_unc; + n_MC_cap=size(ζsP,2), + transP, + transMs=StackedArray(transM, size(ζsMs, 2)) ) - nLys = map(eachcol(ζs[:, 1:n_MC])) do ζi - θ_i, y_pred_i, logjac = predict_y(ζi, xP, f, transPMs, interpreters.PMs) + n_MC = size(ζsP,2) + nLys = map(eachcol(ζsP), eachslice(ζsMs; dims=3)) do ζP, ζMs + θP, θMs, logjac = transform_and_logjac_ζ(ζP, ζMs; transP, transMs) + y_pred_global, y_pred_i = f(θP, θMs, xP) # TODO nLogDen prior on \theta #nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, y_unc) nLy1 = py(y_ob, y_pred_i, y_unc) @@ -98,7 +112,9 @@ function neg_elbo_ζtf(ζs, σ, transPMs, f, py, # sum_log_σ = sum(log.(σ)) # logdet_jacT2 = -sum_log_σ # log Prod(1/σ_i) = -sum log σ_i logdetΣ = 2 * sum(log.(σ)) - entropy_ζ = entropy_MvNormal(size(ζs, 1), logdetΣ) # defined in logden_normal + n_θ = size(ζsP, 1) + prod(size(ζsMs)[1:2]) + @assert length(σ) == n_θ + entropy_ζ = entropy_MvNormal(n_θ, logdetΣ) # defined in logden_normal # if i_sites[1] == 1 # #Main.@infiltrate_main # @show nLy, entropy_ζ, nLmean_θ, n_MC, n_MC_cap, i_sites[1:3] @@ -111,7 +127,7 @@ end () -> begin nLy = reduce( +, map(eachcol(ζs_cpu[:, 1:n_MC])) do ζi - θ_i, y_pred_i, logjac = predict_y(ζi, xP, f, transPMs, interpreters.PMs) + θ_i, y_pred_i, logjac = apply_f_trans(ζi, xP, f, transPMs, interpreters.PMs) # TODO nLogDen prior on \theta #nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, y_unc) nLy1 = py(y_ob, y_pred_i, y_unc) @@ -125,31 +141,53 @@ end end """ - predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpreters; + predict_hvi([rng], prob::AbstractHybridProblem [,xM, xP]; scenario, ...) 
+ predict_hvi(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix; get_transPMs, get_ca_int_PMs, n_sample_pred=200, gdev = identity) -Prediction function for hybrid model. Returns an NamedTuple with entries -- `θ`: ComponentArray `(n_θP + n_site * n_θM), n_sample_pred)` of PBM model parameters. +Prediction function for hybrid variational inference parameter model. + +## Arguments +- The problem for which to predict +- xM: covariates for the machine-learning model (ML): Matrix (n_θM x n_site_pred). +- xP: model drivers for process based model (PBM): Matrix with (n_site_pred) columns. + If provided a ComponentArray with a Tuple-Axis in rows, the PBM model can + access parts of it, e.g. `xP[:S1,...]`. + +## Keyword arguments +- scenario +- n_sample_pred + +Returns a NamedTuple `(; y, θsP, θsMs, entropy_ζ)` with entries - `y`: Array `(n_obs, n_site, n_sample_pred)` of model predictions. +- `θsP`: ComponentArray `(n_θP, n_sample_pred)` of PBM model parameters + that are kept constant across sites. +- `θsMs`: ComponentArray `(n_site, n_θM, n_sample_pred)` of PBM model parameters + that vary by site. +- `entropy_ζ`: The entropy of the log-determinant of the transformation of + the set of model parameters, which is involved in uncertainty quantification. """ -function predict_gf(rng, prob::AbstractHybridProblem; scenario, kwargs...) +function predict_hvi(rng, prob::AbstractHybridProblem; scenario, kwargs...) dl = get_hybridproblem_train_dataloader(prob; scenario) dl_dev = gdev_hybridproblem_dataloader(dl; scenario) + # predict for all sites xM, xP = dl_dev.data[1:2] - predict_gf(rng, prob, xM, xP; scenario, kwargs...) + predict_hvi(rng, prob, xM, xP; scenario, kwargs...) end -function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP; - scenario, - n_sample_pred = 200, - gdev = :use_gpu ∈ scenario ? gpu_device() : identity, - cdev = gdev isa MLDataDevices.AbstractGPUDevice ? 
cpu_device() : identity +function predict_hvi(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP; + scenario, + n_sample_pred=200, + gdev=:use_gpu ∈ _val_value(scenario) ? gpu_device() : identity, + cdev=!(gdev isa MLDataDevices.AbstractGPUDevice) ? identity : + (:f_on_gpu ∈ _val_value(scenario) ? identity : cpu_device()), + kwargs... ) n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) is_predict_batch = (n_batch == length(xP)) n_site_pred = is_predict_batch ? n_batch : n_site - @assert length(xP) == n_site_pred + @assert size(xP, 2) == n_site_pred @assert size(xM, 2) == n_site_pred - f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = !is_predict_batch) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=!is_predict_batch) par_templates = get_hybridproblem_par_templates(prob; scenario) (; θP, θM) = par_templates cor_ends = get_hybridproblem_cor_ends(prob; scenario) @@ -158,76 +196,129 @@ function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP; (; transP, transM) = get_hybridproblem_transforms(prob; scenario) pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars) - (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP, θM, cor_ends, ϕg0, n_site_pred; transP, transM, ϕunc0) + hpints = HybridProblemInterpreters(prob; scenario) + (; ϕ, transPMs_batch, interpreters, get_transPMs) = init_hybrid_params( + θP, θM, cor_ends, ϕg0, hpints; transP, transM, ϕunc0) + int_μP_ϕg_unc = interpreters.μP_ϕg_unc + int_unc = interpreters.unc + transMs = StackedArray(transM, n_batch) g_dev, ϕ_dev = gdev(g), gdev(ϕ) - predict_gf(rng, g_dev, f, ϕ_dev, xM, xP, interpreters; - get_transPMs, get_ca_int_PMs, n_sample_pred, cdev, cor_ends, pbm_covar_indices) + predict_hvi(rng, g_dev, f, ϕ_dev, xM, xP; + int_μP_ϕg_unc, int_unc, transP, transM, + n_sample_pred, cdev, cor_ends, pbm_covar_indices, kwargs...) 
end -function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP, interpreters; - get_transPMs, get_ca_int_PMs, n_sample_pred = 200, - cdev = cpu_device(), - cor_ends, #cor_ends=(P=(1,),M=(1,)) - pbm_covar_indices -) - n_site = size(xM, 2) - intm_PMs_gen = get_ca_int_PMs(n_site) - trans_PMs_gen = get_transPMs(n_site) - interpreters_gen = (; interpreters..., PMs = intm_PMs_gen) - ζs_gpu, σ = generate_ζ(rng, g, CA.getdata(ϕ), CA.getdata(xM), interpreters_gen; - n_MC = n_sample_pred, cor_ends, pbm_covar_indices) - ζs = cdev(ζs_gpu) +function predict_hvi(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP; + int_μP_ϕg_unc::AbstractComponentArrayInterpreter, + int_unc::AbstractComponentArrayInterpreter, + transP, transM, + n_sample_pred=200, + cdev=cpu_device(), + cor_ends, + pbm_covar_indices, + is_inferred::Val{is_infer} = Val(false), + kwargs... +) where is_infer + ζsP, ζsMs, σ = generate_ζ(rng, g, CA.getdata(ϕ), CA.getdata(xM); + int_μP_ϕg_unc, int_unc, + n_MC=n_sample_pred, cor_ends, pbm_covar_indices) + ζsP_cpu = cdev(ζsP) + ζsMs_cpu = cdev(ζsMs) logdetΣ = 2 * sum(log.(σ)) - entropy_ζ = entropy_MvNormal(length(σ), logdetΣ) # defined in logden_normal - (; θ, y) = predict_ζf(ζs, f, xP, trans_PMs_gen, interpreters_gen.PMs) - (; θ, y, entropy_ζ) + entropy_ζ = entropy_MvNormal(length(σ), logdetΣ) + res_pred = is_infer ? + apply_f_trans(ζsP_cpu, ζsMs_cpu, f, xP; transP, transM, kwargs...) : + Test.@inferred apply_f_trans(ζsP_cpu, ζsMs_cpu, f, xP; transP, transM, kwargs...) + (; res_pred..., entropy_ζ) end -function predict_ζf(ζs, f, xP, trans_PMs, interpreter_PMs) - θandy = map(eachcol(ζs)) do ζ - θandy_i= predict_y(ζ, xP, f, trans_PMs, interpreter_PMs)[1:2]; - end - θ1 = first(first(θandy)) - θ = CA.ComponentMatrix( - stack(CA.getdata.(first.(θandy))), (CA.getaxes(θ1)[1], CA.FlatAxis())) - #θ[:P,1] - y = stack(last.(θandy)) - (; θ, y) + +""" +Compute predictions of the transformation at given +transformed parameters for each site. 
+The number of sites is given by the number of rows in `ζsMs`. + +Steps: +- transform the parameters to original constrained space +- Applies the mechanistic model for each site + +`ζsP` and `ζsMs` are shaped according to the output of `generate_ζ`. +Results are of shape `(n_obs x n_site_pred x n_MC)`. +""" +function apply_f_trans(ζsP::AbstractMatrix, ζsMs::AbstractArray, f, xP; + transP, transM::Stacked, + trans_mP=StackedArray(transP, size(ζsP, 2)), + trans_mMs=StackedArray(transM, size(ζsMs, 1) * size(ζsMs, 3)) +) + θsP, θsMs = transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs) + y = apply_f(θsP, θsMs, f, xP) + (; y, θsP, θsMs) +end + +function apply_f_trans(ζP::AbstractVector, ζMs::AbstractMatrix, f, xP; + transP, transM::Stacked, transMs::StackedArray=StackedArray(transM, size(ζMs, 1)), +) + θP = transP(ζP) + θMs = transMs(ζMs) + y_global, y = f(θP, θMs, xP) + (; y, θP, θMs) +end + + +function apply_f(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP) where ET + y_pred = stack(map(eachcol(θsP), eachslice(θsMs, dims=3)) do θP, θMs + y_global, y_pred_i = f(θP, θMs, xP) + y_pred_i + end) end """ Generate samples of (inv-transformed) model parameters, ζ, and the vector of standard deviations, σ, i.e. the diagonal of the cholesky-factor. -Adds the MV-normally distributed residuals, retrieved by `sample_ζ_norm0` +Adds the MV-normally distributed residuals, retrieved by `sample_ζresid_norm` to the means extracted from parameters and predicted by the machine learning model. + +The output shape of size `(n_site x n_par x n_MC)` is tailored to iterating +each MC sample and then transforming each parameter on block across sites. 
""" -function generate_ζ(rng, g, ϕ::AbstractVector, xM::AbstractMatrix, - interpreters::NamedTuple; n_MC = 3, cor_ends, pbm_covar_indices) -# see documentation of neg_elbo_gtf -ϕc = interpreters.μP_ϕg_unc(CA.getdata(ϕ)) -μ_ζP = ϕc.μP -ϕg = ϕc.ϕg -# first pass: append μ_ζP_to covars, need ML prediction for magnitude of ζMs -xMP = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices) -μ_ζMs0 = g(xMP, ϕg) -ζ_resid, σ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC, cor_ends) -#ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC) -# @show size(ζ_resid) -# @show length(interpreters.PMs) -ζ = stack(map(eachcol(ζ_resid)) do r - rc = interpreters.PMs(r) - ζP = μ_ζP .+ rc.P - # second pass: append ζP rather than μ_ζP to covarss to xM - μ_ζMs = _predict_μ_ζMs(xM, ζP, pbm_covar_indices, g, ϕg, μ_ζMs0) - ζMs = μ_ζMs .+ rc.Ms - vcat(ζP, vec(ζMs)) -end) -ζ, σ -end +function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT; + int_μP_ϕg_unc::AbstractComponentArrayInterpreter, + int_unc::AbstractComponentArrayInterpreter, + n_MC=3, cor_ends, pbm_covar_indices) where {FT,MT} + # see documentation of neg_elbo_gtf + ϕc = int_μP_ϕg_unc(CA.getdata(ϕ)) + μ_ζP = CA.getdata(ϕc.μP) + ϕg = CA.getdata(ϕc.ϕg) + # first pass: append μ_ζP_to covars, need ML prediction for magnitude of ζMs + # TODO replace pbm_covar_indices by ComponentArray? dimensions to be type-inferred? 
+ xMP0 = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices) + #Main.@infiltrate_main + μ_ζMs0 = g(xMP0, ϕg)::MT # for gpu restructure returns Any, so apply type + ζP_resids, ζMs_parfirst_resids, σ = sample_ζresid_norm(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC, cor_ends, int_unc) + if pbm_covar_indices isa SA.SVector{0} + # do not need to predict again but just add the residuals to μ_ζP and μ_ζMs + ζsP = (μ_ζP .+ ζP_resids) # n_par x n_MC + ζsMs = permutedims(μ_ζMs0 .+ ζMs_parfirst_resids, (2, 1, 3)) # n_site x n_par x n_MC + else + #rP, rMs = first(zip(eachcol(ζP_resids), eachslice(ζMs_parfirst_resids;dims=3))) + ζst = map(eachcol(ζP_resids), eachslice(ζMs_parfirst_resids; dims=3)) do rP, rMs + ζP = μ_ζP .+ rP + # second pass: append ζP rather than μ_ζP to covars to xM + xMP = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices) + μ_ζMst = g(xMP, ϕg)::MT # for gpu restructure returns Any, so apply type + ζMs = (μ_ζMst .+ rMs)' # already transform to par-last form + ζP, ζMs + end + # ζsP = stack(map(first, ζst); dims=1) # n_MC x n_par + # ζsMs = stack(map(x -> x[2], ζst); dims=1) # n_MC x n_site x n_par + ζsP = stack(map(first, ζst)) # n_par x n_MC + ζsMs = stack(map(x -> x[2], ζst)) # n_site x n_par x n_MC + end + ζsP, ζsMs, σ +end # function _append_PBM_covars(xM, ζP, pbm_covars::NTuple{N,Symbol}) where N # #@show ζP, typeof(ζP) @@ -240,29 +331,28 @@ end # xM # end -function _append_each_covars(xM, ζP::AbstractVector, pbm_covar_indices::SA.StaticVector{0}) +function _append_each_covars(xM, ζP::AbstractVector, pbm_covar_indices::SA.StaticVector{0}) xM end -function _append_each_covars(xM, ζP::AbstractVector, pbm_covar_indices::AbstractVector) +function _append_each_covars(xM, ζP::AbstractVector, pbm_covar_indices::AbstractVector) ζP_covar = ζP[pbm_covar_indices] _append_each_covars(xM, ζP_covar) end -function _append_each_covars(xM, ζP_covar::AbstractVector) +function _append_each_covars(xM, ζP_covar::AbstractVector) #@show ζP, typeof(ζP) @assert eltype(xM) 
== eltype(ζP_covar) #Main.@infiltrate_main - ζP_rep = reduce(hcat, fill(ζP_covar, size(xM,2))) - vcat(xM,ζP_rep) + ζP_rep = reduce(hcat, fill(ζP_covar, size(xM, 2))) + vcat(xM, ζP_rep) end - -function get_pbm_covar_indices(ζP, pbm_covars::NTuple{N,Symbol}, - intP::AbstractComponentArrayInterpreter = ComponentArrayInterpreter(ζP)) where N +function get_pbm_covar_indices(ζP, pbm_covars::NTuple{N,Symbol}, + intP::AbstractComponentArrayInterpreter=ComponentArrayInterpreter(ζP)) where {N} #SA.SVector{N}(CA.getdata(intP(1:length(intP))[pbm_covars])) # can not index into GPUarr CA.getdata(intP(1:length(intP))[pbm_covars]) end -function get_pbm_covar_indices(ζP, pbm_covars::NTuple{0}, - intP::AbstractComponentArrayInterpreter = ComponentArrayInterpreter(ζP)) +function get_pbm_covar_indices(ζP, pbm_covars::NTuple{0}, + intP::AbstractComponentArrayInterpreter=ComponentArrayInterpreter(ζP)) SA.SA[] end @@ -277,7 +367,7 @@ end # end function _predict_μ_ζMs(xM, ζP, pbm_covar_indices::AbstractVector, g, ϕg, μ_ζMs0) - xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices) + xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices) μ_ζMs = g(xMP2, ϕg) end function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕg, μ_ζMs0) @@ -286,26 +376,35 @@ function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕ μ_ζMs0 end - - """ -Extract relevant parameters from θ and return n_MC generated draws -together with the vector of standard deviations, σ. +Extract relevant parameters from ζ and return n_MC generated multivariate normal draws +together with the vector of standard deviations, `σ`: `(ζP_resids, ζMs_parfirst_resids, σ)` +The output shape `(n_θ, n_site?, n_MC)` is tailored to adding `ζMs_parfirst_resids` to +ML-model predcitions of size `(n_θM, n_site)`. 
## Arguments -`int_unc`: Interpret vector as ComponentVector with components - ρsP, ρsM, logσ2_logP, coef_logσ2_logMs(intercept + slope), +* `int_unc`: Interpret vector as ComponentVector with components + ρsP, ρsM, logσ2_ζP, coef_logσ2_ζMs(intercept + slope), """ -function sample_ζ_norm0(rng::Random.AbstractRNG, ζP::AbstractVector, ζMs::AbstractMatrix, - args...; n_MC, cor_ends) +function sample_ζresid_norm(rng::Random.AbstractRNG, ζP::AbstractVector, ζMs::AbstractMatrix, + args...; n_MC, cor_ends, int_unc) n_θP, n_θMs = length(ζP), length(ζMs) - urand = _create_random(rng, CA.getdata(ζP), n_θP + n_θMs, n_MC) - sample_ζ_norm0(urand, ζP, ζMs, args...; cor_ends) + # intm_PMs_parfirst = !isnothing(intm_PMs_parfirst) ? intm_PMs_parfirst : begin + # n_θM, n_site_batch = size(ζMs) + # get_concrete(ComponentArrayInterpreter( + # P = (n_MC, n_θP), Ms = (n_MC, n_θM, n_site_batch))) + # end + #urandn = _create_randn(rng, CA.getdata(ζP), n_MC, n_θP + n_θMs) + urandn = _create_randn(rng, CA.getdata(ζP), n_θP + n_θMs, n_MC) + sample_ζresid_norm(urandn, CA.getdata(ζP), CA.getdata(ζMs), args...; + cor_ends, int_unc=get_concrete(int_unc)) end -function sample_ζ_norm0(urand::AbstractMatrix, ζP::AbstractVector{T}, ζMs::AbstractMatrix, - ϕunc::AbstractVector, int_unc = ComponentArrayInterpreter(ϕunc); cor_ends -) where {T} +function sample_ζresid_norm(urandn::AbstractMatrix, ζP::TP, ζMs::TM, + ϕunc::AbstractVector; + int_unc=get_concrete(ComponentArrayInterpreter(ϕunc)), + cor_ends +) where {T,TP<:AbstractVector{T},TM<:AbstractMatrix{T}} ϕuncc = int_unc(CA.getdata(ϕunc)) n_θP, n_θMs, (n_θM, n_batch) = length(ζP), length(ζMs), size(ζMs) # do not create a UpperTriangular Matrix of an AbstractGÜUArray in transformU_cholesky1 @@ -313,19 +412,61 @@ function sample_ζ_norm0(urand::AbstractMatrix, ζP::AbstractVector{T}, ζMs::Ab UP = transformU_block_cholesky1(ρsP, cor_ends.P) ρsM = isempty(ϕuncc.ρsM) ? 
similar(ϕuncc.ρsM) : ϕuncc.ρsM # required by zygote UM = transformU_block_cholesky1(ρsM, cor_ends.M) - cf = ϕuncc.coef_logσ2_logMs + cf = ϕuncc.coef_logσ2_ζMs logσ2_logMs = vec(cf[1, :] .+ cf[2, :] .* ζMs) - logσ2_logP = vec(CA.getdata(ϕuncc.logσ2_logP)) + logσ2_ζP = vec(CA.getdata(ϕuncc.logσ2_ζP)) # CUDA cannot multiply BlockDiagonal * Diagonal, construct already those blocks σMs = reshape(exp.(logσ2_logMs ./ 2), n_θM, :) - σP = exp.(logσ2_logP ./ 2) + σP = exp.(logσ2_ζP ./ 2) # BlockDiagonal does work with CUDA, but not with combination of Zygote and CUDA # need to construct full matrix for CUDA Uσ = _create_blockdiag(UP, UM, σP, σMs, n_batch) - ζ_resid = Uσ' * urand σ = diag(Uσ) # elements of the diagonal: standard deviations - # returns AbstractGPUuArrays to either continue on GPU or need to transfer to CPU - ζ_resid, σ + n_MC = size(urandn, 2) # TODO transform urandn + ζ_resids_parfirst = Uσ' * urandn # n_par x n_MC + #ζ_resids_parfirst = urandn' * Uσ # n_MC x n_par + ζP_resids = ζ_resids_parfirst[1:n_θP, :] + ζMs_parfirst_resids = reshape(ζ_resids_parfirst[(n_θP+1):end, :], n_θM, n_batch, n_MC) + ζP_resids, ζMs_parfirst_resids, σ + # #map(std, eachcol(ζ_resids_parfirst[:, 3:8])) + # ζ_resid = transpose_mPMs_sitefirst(ζ_resids_parfirst; intm_PMs_parfirst) + # #map(std, eachcol(ζ_resid[:, 3:8])) # all ~ 0.1 in sample_ζresid_norm cpu + # #map(std, eachcol(ζ_resid[:, 2 + n_batch .+ (-1:5)])) # all ~ 100, except first two + # # returns AbstractGPUuArrays to either continue on GPU or need to transfer to CPU + # ζ_resid, σ +end + +""" +Transforms each row of a matrix (n_MC x n_Par) with site parameters Ms inside n_Par +of form (n_par x n_site) to Ms of the form (n_site x n_par), i.e. +neighboring entries (inside a column) are of the same parameter. + +This format of having n_par as the last dimension helps transforming parameters +on block. 
+""" +function transpose_mPMs_sitefirst(Xt, n_θP::Integer, n_θM, n_site_batch, n_MC) + # cannot make n_θP keyword arguments, because it overrides method below + intm_PMs_parfirst = ComponentArrayInterpreter( + P=(n_MC, n_θP), Ms=(n_MC, n_θM, n_site_batch)) + transpose_mPMs_sitefirst(Xt; intm_PMs_parfirst) +end +function transpose_mPMs_sitefirst(Xt; + intm_PMs_parfirst=ComponentArrayInterpreter( + P=(n_MC, n_θP), Ms=(n_MC, n_θM, n_site_batch)) +) + Xtc = intm_PMs_parfirst(Xt) + # Main.@infiltrate_main + + # _Ms = Xtc.Ms + # map(std, eachrow(_Ms[:,1:6,:])) + + # map(std, eachrow(tmp[3:8,:])) + # _Ms = permutedims(Xtc.Ms, (1, 3, 2)) + X_site_first = CA.ComponentVector(P=Xtc.P, Ms=permutedims(Xtc.Ms, (1, 3, 2))) + reshape(CA.getdata(X_site_first), size(Xt))::typeof(CA.getdata(Xt)) + # X_site_first = CA.ComponentVector( + # P = permutedims(Xtc.P), Ms = permutedims(Xtc.Ms, (3, 2, 1))) + # reshape(CA.getdata(X_site_first), rev(size(Xt)))::typeof(CA.getdata(Xt)) end function _create_blockdiag(UP::AbstractMatrix{T}, UM, σP, σMs, n_batch) where {T} @@ -333,14 +474,14 @@ function _create_blockdiag(UP::AbstractMatrix{T}, UM, σP, σMs, n_batch) where BlockDiagonal(v) end function _create_blockdiag( - UP::GPUArraysCore.AbstractGPUMatrix{T}, UM, σP, σMs, n_batch) where {T} + UP::GPUArraysCore.AbstractGPUMatrix{T}, UM, σP, σMs, n_batch) where {T} # using BlockDiagonal leads to Scalar operations downstream # v = [i == 0 ? UP * Diagonal(σP) : UM * Diagonal(σMs[:, i]) for i in 0:n_batch] # BlockDiagonal(v) # Uσ = cat([i == 0 ? UP * Diagonal(σP) : UM * Diagonal(σMs[:, i]) for i in 0:n_batch]...; # dims=(1, 2)) # on GPU use only one big multiplication rather than many small ones - U = cat([i == 0 ? UP : UM for i in 0:n_batch]...; dims = (1, 2)) + U = cat([i == 0 ? 
UP : UM for i in 0:n_batch]...; dims=(1, 2)) #Main.@infiltrate_main σD = Diagonal(vcat(σP, vec(σMs))) Uσ = U * σD @@ -349,43 +490,50 @@ function _create_blockdiag( tmp = vcat(Uσ) end -function _create_random(rng, ::AbstractVector{T}, dims...) where {T} - rand(rng, T, dims...) +function _create_randn(rng, ::AbstractVector{T}, dims...) where {T} + randn(rng, T, dims...) end #moved to HybridVariationalInferenceCUDAExt -#function _create_random(rng, ::CUDA.CuVector{T}, dims...) where {T} +#function _create_randn(rng, ::CUDA.CuVector{T}, dims...) where {T} -""" -Compute predictions and log-Determinant of the transformation at given -transformed parameters for each site. +""" +Transform parameters and compute absolute of determinant of Jacobian of the transformation. +- from unconstrained (e.g. log) ζ scale of format (n_site x n_par) +- to constrained θ scale of format (n_site x n_par) +""" -The number of sites is given by the number of columns in `Ms`, which is determined -by the transformation, `transPMs`. +function transform_and_logjac_ζ(ζP::AbstractVector, ζMs::AbstractMatrix; + transP::Bijectors.Transform, + transMs::StackedArray=StackedArray(transM, size(ζMs, 1))) + θP, logjac_P = Bijectors.with_logabsdet_jacobian(transP, ζP) + θMs, logjac_M = Bijectors.with_logabsdet_jacobian(transMs, ζMs) + θP, θMs, logjac_P + logjac_M +end -Steps: -- transform the parameters to original constrained space -- Applies the mechanistic model for each site """ -function predict_y(ζi, xP, f, transPMs::Bijectors.Transform, - int_PMs::AbstractComponentArrayInterpreter) - θc, logjac = transform_ζ(ζi, transPMs, int_PMs) - #θc, logjac = int_PMs(ζi), eltype(ζi)(0) - y_pred_global, y_pred = f(θc.P, θc.Ms, xP) - # Main.@infiltrate_main - # @benchmark f(θc.P, θc.Ms, xP) - #y_pred_global, y_pred = f(θc.P, θc.Ms, xPg) - # TODO take care of y_pred_global - θc, y_pred, logjac +Transform parameters +- from unconstrained (e.g. 
log) ζ scale of format ((n_site x n_par) x n_mc) +- to constrained θ scale of the same format +""" +function transform_ζs(ζsP::AbstractMatrix, ζsMs::AbstractArray; + trans_mP::StackedArray=StackedArray(transP, n_MC), + trans_mMs::StackedArray=StackedArray(transM, n_MC * n_site_batch) +) + # transform to parameter-last that can apply transformations effectively + θsPt = trans_mP(ζsP') + θsMst = trans_mMs(permutedims(ζsMs, (3, 1, 2))) + # backtransform to n_mc last for efficient mapping? + # TODO test if faster than mapping + θsP = θsPt' + θsMs = permutedims(θsMst, (2, 3, 1)) + θsP, θsMs end -function transform_ζ(ζi, transPMs::Bijectors.Transform, - int_PMs::AbstractComponentArrayInterpreter) - # θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating - # θc = CA.ComponentVector(θtup) - #replace with more flexible transPMs after trying CUDA/Zygote - #θ, logjac = exp.(ζi), sum(ζi) - θ, logjac = Bijectors.with_logabsdet_jacobian(transPMs, ζi) # both allocating - θc = int_PMs(θ) - θc, logjac +function flatten_hybrid_pars(xsP::AbstractMatrix{FT}, xsMs::AbstractArray{FT,3}) where FT + n_site_pred, n_θM, n_MC = size(xsMs) + @assert size(xsP,2) == n_MC + vcat(xsP, reshape(xsMs, n_site_pred * n_θM, n_MC)) end + + diff --git a/src/gf.jl b/src/gf.jl index e05a66b..2fefe2d 100644 --- a/src/gf.jl +++ b/src/gf.jl @@ -1,63 +1,105 @@ -function applyf(f, θMs::AbstractMatrix, θP::AbstractVector, θFix::AbstractVector, xP, args...; kwargs...) +# Point solver where ML directly predicts PBL parameters, rather than their +# distribution. + +""" +Map process base model (PBM), `f`, across each site. + +## Arguments +- `f(θ, xP, args...; intθ1, kwargs...)`: Process based model for single site + + Make sure to hint the type, so that results can be inferred. +- `θMst`: transposed model parameters across sites matrix: (n_parM, n_site_batch) +- `θP`: transposed model parameters that do not differ by site: (n_parP,) +- `θFix`: Further parameter required by f that are not calibrated. 
+- `xP`: Model drivers: Matrix with n_site_batch columns. + If provided a ComponentArray with labeled rows, f can then access `xP[:key]`. +- `intθ1`: ComponentArrayInterpreter that can be applied to θ, + so that entries can be extracted. + +See test_HybridProblem of using this function to construct a PBM function that +can predict across all sites. +""" +function map_f_each_site( + f, θMst::AbstractMatrix, θP::AbstractVector, θFix::AbstractVector, xP, args...; + intθ1::AbstractComponentArrayInterpreter, kwargs... +) # predict several sites with same global parameters θP and fixed parameters θFix - #θM, x_site = first(zip(eachcol(θMs), xP)) - yv = map(eachcol(θMs), xP) do θM, x_site - f(vcat(θP, θM, θFix), x_site, args...; kwargs...) - end + it1 = eachcol(CA.getdata(θMst)) + it2 = eachcol(xP) + _θM = first(it1) + _x_site = first(it2) + TXS = typeof(_x_site) + TY = typeof(f(vcat(θP, _θM, θFix), _x_site, args...; intθ1, kwargs...)) + #TY = typeof(f(vcat(θP, _θM, θFix), _x_site; intθ1)) + yv = map(it1, it2) do θM, x_site + x_site_typed = x_site::TXS + f(vcat(θP, θM, θFix), x_site_typed, args...; intθ1, kwargs...) + end::Vector{TY} y = stack(yv) return(y) end -function applyf(f, θMs::AbstractMatrix, θPs::AbstractMatrix, θFix::AbstractVector, xP, args...; kwargs...) - # do not call f with matrix θ, because .* with vectors S1 would go wrong - yv = map(eachcol(θMs), eachcol(θPs), xP) do θM, θP, xP_site - f(vcat(θP, θM, θFix), xP_site, args...; kwargs...) - end - y = stack(yv) - return(y) -end -#applyf(f_double, θMs_true, stack(Iterators.repeated(CA.getdata(θP_true), size(θMs_true,2)))) +# function map_f_each_site(f, θMs::AbstractMatrix, θPs::AbstractMatrix, θFix::AbstractVector, xP, args...; kwargs...) +# # do not call f with matrix θ, because .* with vectors S1 would go wrong +# yv = map(eachcol(θMs), eachcol(θPs), xP) do θM, θP, xP_site +# f(vcat(θP, θM, θFix), xP_site, args...; kwargs...) 
+# end +# y = stack(yv) +# return(y) +# end +# #map_f_each_site(f_double, θMs_true, stack(Iterators.repeated(CA.getdata(θP_true), size(θMs_true,2)))) """ composition f ∘ transM ∘ g: mechanistic model after machine learning parameter prediction """ -function gf(prob::AbstractHybridProblem, args...; scenario = (), kwargs...) +function gf(prob::AbstractHybridProblem; scenario = Val(()), kwargs...) train_loader = get_hybridproblem_train_dataloader(prob; scenario) train_loader_dev = gdev_hybridproblem_dataloader(train_loader; scenario) xM, xP = train_loader_dev.data[1:2] - gf(prob, xM, xP, args...; kwargs...) + gf(prob, xM, xP; scenario, kwargs...) end -function gf(prob::AbstractHybridProblem, xM::AbstractMatrix, xP::AbstractVector, args...; - scenario = (), - gdev = :use_gpu ∈ scenario ? gpu_device() : identity, +function gf(prob::AbstractHybridProblem, xM::AbstractMatrix, xP::AbstractMatrix; + scenario = Val(()), + gdev = :use_gpu ∈ _val_value(scenario) ? gpu_device() : identity, cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity, - kwargs...) + is_inferred::Val{is_infer} = Val(false), + kwargs... +) where is_infer g, ϕg = get_hybridproblem_MLapplicator(prob; scenario) n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) - is_predict_batch = (n_batch == length(xP)) + is_predict_batch = (n_batch == size(xP,2)) n_site_pred = is_predict_batch ? 
n_batch : n_site - @assert length(xP) == n_site_pred + @assert size(xP, 2) == n_site_pred @assert size(xM, 2) == n_site_pred f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = !is_predict_batch) (; θP, θM) = get_hybridproblem_par_templates(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) + transMs = StackedArray(transM, n_site_pred) intP = ComponentArrayInterpreter(θP) pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) pbm_covar_indices = CA.getdata(intP(1:length(intP))[pbm_covars]) ζP = inverse(transP)(θP) g_dev, ϕg_dev, ζP_dev = (gdev(g), gdev(ϕg), gdev(CA.getdata(ζP))) - gf(g_dev, transM, transP, f, xM, xP, ϕg_dev, ζP_dev, pbm_covar_indices; cdev, kwargs...) + # most of the properties of prob are not type-inferred + # hence result is not type-inferred, but may test at this context + res = is_infer ? + Test.@inferred( gf( + g_dev, transMs, transP, f, xM, xP, ϕg_dev, ζP_dev, pbm_covar_indices; + cdev, kwargs...)) : + gf(g_dev, transMs, transP, f, xM, xP, ϕg_dev, ζP_dev, pbm_covar_indices; + cdev, kwargs...) end -function gf(g::AbstractModelApplicator, transM, transP, f, xM, xP, ϕg, ζP; - cdev = identity, pbm_covars, +function gf(g::AbstractModelApplicator, transMs, transP, f, xM, xP, ϕg, ζP; + cdev, pbm_covars, intP = ComponentArrayInterpreter(ζP), kwargs...) pbm_covar_indices = intP(1:length(intP))[pbm_covars] gf(g, transM, transP, f, xM, xP, ϕg, ζP, pbm_covar_indices; kwargs...) 
end -function gf(g::AbstractModelApplicator, transM, transP, f, xM, xP, ϕg, ζP, pbm_covar_indices::AbstractVector{<:Integer}; - cdev = identity) +function gf(g::AbstractModelApplicator, transMs, transP, f, xM, xP, ϕg, ζP, + pbm_covar_indices::AbstractVector{<:Integer}; + cdev) # @show first(xM,5) # @show first(ϕg,5) @@ -67,7 +109,7 @@ function gf(g::AbstractModelApplicator, transM, transP, f, xM, xP, ϕg, ζP, pbm # end #xMP = _append_PBM_covars(xM, intP(ζP), pbm_covars) xMP = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices) - θMs = gtrans(g, transM, xMP, ϕg; cdev) + θMs = gtrans(g, transMs, xMP, ϕg; cdev) θP = transP(CA.getdata(ζP)) θP_cpu = cdev(θP) y_pred_global, y_pred = f(θP_cpu, θMs, xP) @@ -76,12 +118,16 @@ end """ composition transM ∘ g: transformation after machine learning parameter prediction +Provide a `transMs = StackedArray(transM, n_batch)` """ -function gtrans(g, transM, xMP, ϕg; cdev = identity) - ζMs = g(xMP, ϕg) # predict the log of the parameters +function gtrans(g, transMs, xMP::T, ϕg; cdev) where T + # TODO remove after removing gf + # predict the log of the parameters + ζMst = g(xMP, ϕg)::T # problem of Flux model applicator restructure + ζMs = ζMst' ζMs_cpu = cdev(ζMs) - # TODO move to gpu, Zygote needs to work with transM - θMs = reduce(hcat, map(transM, eachcol(ζMs_cpu))) # transform each column + θMs = transMs(ζMs_cpu) + #θMs = reduce(hcat, map(transM, eachcol(ζMs_cpu))) # transform each row end @@ -95,8 +141,8 @@ Create a loss function for given - intP: interpreter attaching axis to ζP = ϕP with components used by f - kwargs: additional keyword arguments passed to gf, such as gdev or pbm_covars -The loss function `loss_gf(p, xM, xP, y_o, y_unc, i_sites)` takes -- parameter vector p +The loss function `loss_gf(ϕ, xM, xP, y_o, y_unc, i_sites)` takes +- parameter vector ϕ - xM: matrix of covariate, sites in the batch are in columns - xP: iteration of drivers for each site - y_o: matrix of observations, sites in columns @@ -107,27 
+153,43 @@ function get_loss_gf(g, transM, transP, f, y_o_global, intϕ::AbstractComponentArrayInterpreter, intP::AbstractComponentArrayInterpreter = ComponentArrayInterpreter( intϕ(1:length(intϕ)).ϕP); - pbm_covars, kwargs...) + cdev=cpu_device(), + pbm_covars, n_site_batch, kwargs...) let g = g, transM = transM, transP = transP, f = f, y_o_global = y_o_global, intϕ = get_concrete(intϕ), + transMs = StackedArray(transM, n_site_batch), + cdev = cdev, pbm_covar_indices = CA.getdata(intP(1:length(intP))[pbm_covars]) #, intP = get_concrete(intP) #inv_transP = inverse(transP), kwargs = kwargs - function loss_gf(p, xM, xP, y_o, y_unc, i_sites) + function loss_gf(ϕ, xM, xP, y_o, y_unc, i_sites) σ = exp.(y_unc ./ 2) - pc = intϕ(p) + ϕc = intϕ(ϕ) + # μ_ζP = ϕc.ϕP + # xMP = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices) + # ϕ_M = g(xMP, CA.getdata(ϕc.ϕg)) + # μ_ζMs = ϕ_M' + # ζP_cpu = cdev(CA.getdata(μ_ζP)) + # ζMs_cpu = cdev(CA.getdata(μ_ζMs)) + # y_pred, _, _ = apply_f_trans(ζP_cpu, ζMs_cpu, f, xP; transM, transP) y_pred_global, y_pred, θMs, θP = gf( - g, transM, transP, f, xM, xP, CA.getdata(pc.ϕg), CA.getdata(pc.ϕP), - pbm_covar_indices; kwargs...) - loss = sum(abs2, (y_pred .- y_o) ./ σ) + sum(abs2, y_pred_global .- y_o_global) - return loss, y_pred_global, y_pred, θMs, θP + g, transMs, transP, f, xM, xP, CA.getdata(ϕc.ϕg), CA.getdata(ϕc.ϕP), + pbm_covar_indices; cdev, kwargs...) 
+ loss = sum(abs2, (y_pred .- y_o) ./ σ) #+ sum(abs2, y_pred_global .- y_o_global) + return loss, y_pred, θMs, θP end end end - -() -> begin - loss_gf(p, xM, y_o) - Zygote.gradient(x -> loss_gf(x, xM, y_o)[1], p) -end +# function tmp_fcost(is,intθ,fneglogden ) +# fcost = let is = is, intθ = intθ,fneglogden=fneglogden +# fcost_inner = (θvec, xPM, y_o, y_unc) -> begin +# θ = hcat(CA.getdata(θvec.P[is]), CA.getdata(θvec.Ms')) +# y = DoubleMM.f_doubleMM(θ, xPM, intθ) +# #y = CP.DoubleMM.f_doubleMM(θ, xPM, θpos) +# res = fneglogden(y_o, y', y_unc) +# res +# end +# end +# end diff --git a/src/hybridprobleminterpreters.jl b/src/hybridprobleminterpreters.jl new file mode 100644 index 0000000..cf39ea8 --- /dev/null +++ b/src/hybridprobleminterpreters.jl @@ -0,0 +1,63 @@ +abstract type AbstractHybridProblemInterpreters end + +struct HybridProblemInterpreters{AXP, AXM, NS, NB} <: AbstractHybridProblemInterpreters +end; + +const HPInts = HybridProblemInterpreters + +# function get_hybridproblem_statics(prob::AbstractHybridProblem, scenario) +# θP, θM = get_hybridproblem_par_templates(prob; scenario) +# NS, NB = get_hybridproblem_n_site_and_batch(prob; scenario) +# (CA.getaxes(θP), CA.getaxes(θM), NS, NB) +# end + +function HybridProblemInterpreters(prob::AbstractHybridProblem; scenario::Val) + # make sure interred get_hybridproblem_par_templates and n_site_and_n_batch + # error("'HybridProblemInterpreters(prob::AbstractHybridProblem; scenario)'", + # "is not inferred at caller level. 
Replace by ", + # "'HybridProblemInterpreters{get_hybridproblem_statics(prob; scenario)...}()'") + θP, θM = get_hybridproblem_par_templates(prob; scenario) + NS, NB = get_hybridproblem_n_site_and_batch(prob; scenario) + HybridProblemInterpreters{CA.getaxes(θP), CA.getaxes(θM), NS, NB}() +end + +function get_int_P(::HPInts{AXP}) where AXP + StaticComponentArrayInterpreter{AXP}() +end +function get_int_M(::HPInts{AXP,AXM}) where {AXP,AXM} + StaticComponentArrayInterpreter{AXM}() +end +function get_int_Ms_batch(ints::HPInts{AXP,AXM, NS, NB}) where {AXP,AXM,NS,NB} + StaticComponentArrayInterpreter(AXM, (NB,)) +end +function get_int_Mst_batch(ints::HPInts{AXP,AXM, NS, NB}) where {AXP,AXM,NS,NB} + StaticComponentArrayInterpreter((NB,), AXM) +end +function get_int_Ms_site(ints::HPInts{AXP,AXM, NS, NB}) where {AXP,AXM,NS,NB} + StaticComponentArrayInterpreter(AXM, (NS,)) +end +function get_int_Mst_site(ints::HPInts{AXP,AXM, NS, NB}) where {AXP,AXM,NS,NB} + StaticComponentArrayInterpreter((NS,), AXM) +end + +function get_int_PMs_batch(ints::HPInts{AXP,AXM, NS, NB}) where {AXP,AXM,NS,NB} + AX_MS = CA.getaxes(get_int_Ms_batch(ints)) + AX_PMs = compose_axes((;P=AXP, Ms=AX_MS)) + StaticComponentArrayInterpreter{(AX_PMs,)}() +end +function get_int_PMst_batch(ints::HPInts{AXP,AXM, NS, NB}) where {AXP,AXM,NS,NB} + AX_MS = CA.getaxes(get_int_Mst_batch(ints)) # note the t after Ms + AX_PMs = compose_axes((;P=AXP, Ms=AX_MS)) + StaticComponentArrayInterpreter{(AX_PMs,)}() +end +function get_int_PMs_site(ints::HPInts{AXP,AXM, NS, NB}) where {AXP,AXM,NS,NB} + AX_MS = CA.getaxes(get_int_Ms_site(ints)) + AX_PMs = compose_axes((;P=AXP, Ms=AX_MS)) + StaticComponentArrayInterpreter{(AX_PMs,)}() +end +function get_int_PMst_site(ints::HPInts{AXP,AXM, NS, NB}) where {AXP,AXM,NS,NB} + AX_MS = CA.getaxes(get_int_Mst_site(ints)) # note the t after Ms + AX_PMs = compose_axes((;P=AXP, Ms=AX_MS)) + StaticComponentArrayInterpreter{(AX_PMs,)}() +end + diff --git a/src/init_hybrid_params.jl 
b/src/init_hybrid_params.jl index dd74e07..a1567a0 100644 --- a/src/init_hybrid_params.jl +++ b/src/init_hybrid_params.jl @@ -23,7 +23,7 @@ Returns a NamedTuple of - `ϕunc0` initial uncertainty parameters, ComponentVector with format of `init_hybrid_ϕunc.` """ function init_hybrid_params(θP::AbstractVector{FT}, θM::AbstractVector{FT}, - cor_ends::NamedTuple, ϕg::AbstractVector{FT}, n_batch; + cor_ends::NamedTuple, ϕg::AbstractVector{FT}, hpints::HybridProblemInterpreters; transP = elementwise(identity), transM = elementwise(identity), ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT))) where {FT} n_θP = length(θP) @@ -39,38 +39,40 @@ function init_hybrid_params(θP::AbstractVector{FT}, θM::AbstractVector{FT}, ϕg = ϕg, unc = ϕunc0) # - get_transPMs = let transP = transP, transM = transM, n_θP = n_θP, n_θM = n_θM - function get_transPMs_inner(n_site) - transMs = ntuple(i -> transM, n_site) - ranges = vcat( - [1:n_θP], [(n_θP + i0 * n_θM) .+ (1:n_θM) for i0 in 0:(n_site - 1)]) - transPMs = Stacked((transP, transMs...), ranges) - transPMs - end - end - transPMs_batch = get_transPMs(n_batch) + # get_transPMs = let transP = transP, transM = transM, n_θP = n_θP, n_θM = n_θM + # function get_transPMs_inner(n_site) + # transMs = ntuple(i -> transM, n_site) + # ranges = vcat( + # [1:n_θP], [(n_θP + i0 * n_θM) .+ (1:n_θM) for i0 in 0:(n_site - 1)]) + # transPMs = Stacked((transP, transMs...), ranges) + # transPMs + # end + # end + get_transPMs = transPMs_batch = Val(Symbol("deprecated , use stack_ca_int(intPMs)")) + #transPMs_batch = get_transPMs(n_batch) # ranges = (P = 1:n_θP, ϕg = n_θP .+ (1:n_ϕg), unc = (n_θP + n_ϕg) .+ (1:length(ϕunc0))) # inv_trans_gu = Stacked( # (inverse(transP), elementwise(identity), elementwise(identity)), values(ranges)) # ϕ = inv_trans_gu(CA.getdata(ϕt)) - get_ca_int_PMs = let - function get_ca_int_PMs_inner(n_site) - ComponentArrayInterpreter(CA.ComponentVector(; P = θP, - Ms = CA.ComponentMatrix( - zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i 
= 1:n_site)))) - end - end + get_ca_int_PMs = Val(Symbol("deprecated , use get_int_PMst_site(HybridProblemInterpreters(prob; scenario))")) + # get_ca_int_PMs = let + # function get_ca_int_PMs_inner(n_site) + # ComponentArrayInterpreter(CA.ComponentVector(; P = θP, + # Ms = CA.ComponentMatrix( + # zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i = 1:n_site)))) + # end + # end interpreters = map(get_concrete, (; μP_ϕg_unc = ComponentArrayInterpreter(ϕ), - PMs = get_ca_int_PMs(n_batch), + PMs = get_int_PMst_batch(hpints), unc = ComponentArrayInterpreter(ϕunc0) )) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) end """ - init_hybrid_ϕunc(cor_ends, ρ0=0f0; logσ2_logP, coef_logσ2_logMs, ρsP, ρsM) + init_hybrid_ϕunc(cor_ends, ρ0=0f0; logσ2_ζP, coef_logσ2_ζMs, ρsP, ρsM) Initialize vector of additional parameter of the approximate posterior. @@ -78,12 +80,12 @@ Arguments: - `cor_ends`: NamedTuple with entries, `P`, and `M`, respectively with integer vectors of ending columns of parameters blocks - `ρ0`: default entry for ρsP and ρsM, defaults = 0f0. -- `coef_logσ2_logM`: default column for `coef_logσ2_logMs`, defaults to `[-10.0, 0.0]` +- `coef_logσ2_logM`: default column for `coef_logσ2_ζMs`, defaults to `[-10.0, 0.0]` Returns a `ComponentVector` of -- `logσ2_logP`: vector of log-variances of ζP (on log scale). +- `logσ2_ζP`: vector of log-variances of ζP (on log scale). defaults to -10 -- `coef_logσ2_logMs`: offset and slope for the log-variances of ζM scaling with +- `coef_logσ2_ζMs`: offset and slope for the log-variances of ζM scaling with its value given by columns for each parameter in ζM, defaults to `[-10, 0]` - `ρsP` and `ρsM`: parameterization of the upper triangular cholesky factor of the correlation matrices of ζP and ζM, default to all entries `ρ0`, which defaults to zero. 
@@ -92,15 +94,25 @@ function init_hybrid_ϕunc( cor_ends::NamedTuple, ρ0::FT = 0.0f0, coef_logσ2_logM::AbstractVector{FT} = FT[-10.0, 0.0]; - logσ2_logP::AbstractVector{FT} = fill(FT(-10.0), cor_ends.P[end]), - coef_logσ2_logMs::AbstractMatrix{FT} = reduce( + logσ2_ζP::AbstractVector{FT} = fill(FT(-10.0), cor_ends.P[end]), + coef_logσ2_ζMs::AbstractMatrix{FT} = reduce( hcat, (coef_logσ2_logM for _ in 1:cor_ends.M[end])), ρsP = fill(ρ0, get_cor_count(cor_ends.P)), ρsM = fill(ρ0, get_cor_count(cor_ends.M)), ) where {FT} - CA.ComponentVector(; - logσ2_logP, - coef_logσ2_logMs, + nt = (; + logσ2_ζP, + coef_logσ2_ζMs, ρsP, ρsM) + ca = CA.ComponentVector(;nt...)::CA.ComponentVector end + +# macro gen_unc(nt) +# quote +# nt_ev = $(esc(nt)) +# int_nt = StaticComponentArrayInterpreter(map(x -> Val(size(x)), nt_ev)) +# int_nt(CA.getdata(CA.ComponentVector(;nt_ev...))) +# end +# end + diff --git a/src/util_ca.jl b/src/util_ca.jl index be9641f..63561da 100644 --- a/src/util_ca.jl +++ b/src/util_ca.jl @@ -10,4 +10,43 @@ function apply_preserve_axes(f, ca::CA.ComponentArray) CA.ComponentArray(f(CA.getdata(ca)), CA.getaxes(ca)) end +""" + compose_axes(axtuples::NamedTuple) + +Create a new 1d-axis that combines several other named axes-tuples +such as of `key = getaxes(::AbstractComponentArray)`. + +The new axis consists of several ViewAxes. If an axis-tuple consists only of one axis, it is used for the view. +Otherwise a ShapedAxis is created with the axes-length of the others, essentially dropping +component information that might be present in the dimensions. +""" +function compose_axes(axtuples::NamedTuple) + ls = map(axtuple -> Val(prod(axis_length.(axtuple))), axtuples) + # to work on types, need to construct value types of intervals + intervals = _construct_invervals(;lengths=ls) + named_intervals = (;zip(keys(axtuples),intervals)...) + axc = map(named_intervals, axtuples) do interval, axtuple + ax = length(axtuple) == 1 ? 
axtuple[1] : CA.ShapedAxis(axis_length.(axtuple)) + CA.ViewAxis(_val_value(interval), ax) + end + CA.Axis(; axc...) +end + +function _construct_invervals(;lengths) + reduce((ranges,length) -> _add_interval(;ranges, length), + Iterators.tail(lengths), init=(Val(1:_val_value(first(lengths))),)) +end +function _add_interval(;ranges, length::Val{l}) where {l} + ind_before = last(_val_value(last(ranges))) + (ranges...,Val(ind_before .+ (1:l))) +end +_val_value(::Val{x}) where x = x + + +axis_length(ax::CA.AbstractAxis) = CA.lastindex(ax) - CA.firstindex(ax) + 1 +axis_length(::CA.FlatAxis) = 0 +axis_length(ax::CA.UnitRange) = length(ax) +axis_length(ax::CA.ShapedAxis) = length(ax) +axis_length(ax::CA.Shaped1DAxis) = length(ax) + diff --git a/src/util_opt.jl b/src/util_opt.jl index f51b427..2baf742 100644 --- a/src/util_opt.jl +++ b/src/util_opt.jl @@ -21,3 +21,7 @@ callback_loss_fstate = (moditer, fstate) -> let iter = 1, moditer = moditer, fst return false end end + + + + diff --git a/test/Project.toml b/test/Project.toml index 5b17786..d1e29ea 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921" diff --git a/test/runtests.jl b/test/runtests.jl index 2c87aa2..c41db77 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,8 +5,12 @@ const GROUP = get(ENV, "GROUP", "All") # defined in in CI.yml if GROUP == "All" || GROUP == "Basic" #@safetestset "test" include("test/test_bijectors_utils.jl") @time @safetestset "test_bijectors_utils" include("test_bijectors_utils.jl") + #@safetestset "test" include("test/test_util_ca.jl") + @time @safetestset "test_util_ca" include("test_util_ca.jl") #@safetestset "test" 
include("test/test_ComponentArrayInterpreter.jl") @time @safetestset "test_ComponentArrayInterpreter" include("test_ComponentArrayInterpreter.jl") + #@safetestset "test" include("test/test_hybridprobleminterpreters.jl") + @time @safetestset "test_hybridprobleminterpreters" include("test_hybridprobleminterpreters.jl") #@safetestset "test" include("test/test_ModelApplicator.jl") @time @safetestset "test_ModelApplicator" include("test_ModelApplicator.jl") #@safetestset "test" include("test/test_gencovar.jl") diff --git a/test/test_ComponentArrayInterpreter.jl b/test/test_ComponentArrayInterpreter.jl index f3b8fda..1286bf5 100644 --- a/test/test_ComponentArrayInterpreter.jl +++ b/test/test_ComponentArrayInterpreter.jl @@ -1,23 +1,50 @@ using Test using HybridVariationalInference -using HybridVariationalInference: HybridVariationalInference as CM +using HybridVariationalInference: HybridVariationalInference as CP using ComponentArrays: ComponentArrays as CA -@testset "ComponentArrayInterpreter vector" begin - component_counts = comp_cnts = (; P=2, M=3, Unc=5) - m = ComponentArrayInterpreter(; comp_cnts...) 
+using MLDataDevices, GPUArraysCore +import Zygote + +# import CUDA, cuDNN +using Suppressor + +gdev = Suppressor.@suppress gpu_device() # not loaded CUDA +cdev = cpu_device() + +@testset "construct StaticComponentArrayInterepreter" begin + intv = @inferred CP.StaticComponentArrayInterpreter(CA.ComponentVector(a=1:3, b=reshape(4:9,3,2))) + ints = @inferred CP.StaticComponentArrayInterpreter((;a=Val(3), b = Val((3,2)))) + # @descend_code_warntype CP.StaticComponentArrayInterpreter((;a=Val(3), b = Val((3,2)))) + @test ints == intv +end + +@testset "ComponentArrayInterpreter cv-vector" begin + component_counts = comp_cnts = (; P=2, M=3, Unc=5) + comp_cnts_val = (; P=Val(2), M=Val(3), Unc=Val(5)) + #component_counts = comp_cnts = CA.ComponentVector(P=1:2, M=1:3, Unc=1:5) + + m = @inferred ComponentArrayInterpreter(comp_cnts) + m2 = @inferred CP.StaticComponentArrayInterpreter(comp_cnts_val) + get_positions(m) testm = (m) -> begin #type of axes may differ - #@test CM._get_ComponentArrayInterpreter_axes(m) == (CA.Axis(P=1:2, M=3:5, Unc=6:10),) + #@test CP._get_ComponentArrayInterpreter_axes(m) == (CA.Axis(P=1:2, M=3:5, Unc=6:10),) @test length(m) == 10 v = 1:length(m) cv = m(v) @test cv.Unc == 6:10 end testm(m) + #m = @inferred get_concrete(m) m = get_concrete(m) testm(get_concrete(m)) Base.isconcretetype(typeof(m)) + + cc0 = CA.ComponentVector(comp_cnts) + sum(get_positions(ComponentArrayInterpreter(cc0))) + Zygote.gradient(cc -> sum(cc), cc0) + Zygote.gradient(cc -> sum(get_positions(ComponentArrayInterpreter(cc))), cc0) end; # () -> begin @@ -32,19 +59,29 @@ end; @testset "ComponentArrayInterpreter matrix in vector" begin component_shapes = (; P=2, M=(2, 3), Unc=5) - m = ComponentArrayInterpreter(; component_shapes...) 
+ #component_shapes = CA.ComponentVector(P=1:2, M=reshape(1:6,2, 3), Unc=1:5) + m = @inferred ComponentArrayInterpreter(component_shapes) testm = (m) -> begin @test length(m) == 13 a = 1:length(m) - cv = m(a) + cv = m isa CP.StaticComponentArrayInterpreter ? @inferred(m(a)) : m(a) @test cv.M == 2 .+ [1 3 5; 2 4 6] + cv end testm(m) - testm(get_concrete(m)) + @inferred testm(get_concrete(m)) + # test creating ComponentArrayInterpreter insite differentiated function + tmpf = (a) -> begin + m = ComponentArrayInterpreter(component_shapes) + cv = m(a) + sum(cv.M) + end + Zygote.gradient(tmpf, 1:length(m)) end; @testset "ComponentArrayInterpreter matrix and array" begin mv = ComponentArrayInterpreter(; c1=2, c2=3) + #mv = ComponentArrayInterpreter(CA.ComponentVector(c1=1:2, c2=1:3)) cv = mv(1:length(mv)) n_col = 4 mm = ComponentArrayInterpreter(cv, (n_col,)) # 1-tuple @@ -85,6 +122,61 @@ end; testm(mmi) end; +@testset "stack_ca_int" begin + mv = get_concrete(ComponentArrayInterpreter(CA.ComponentVector(c1=1:2, c2=1:3))) + #mv = ComponentArrayInterpreter(CA.ComponentVector(c1=1:2, c2=1:3)) + cv = mv(1:length(mv)) + n_col = 4 + n_dims = (n_col,) + mm = @inferred CP.stack_ca_int(mv, Val((n_col,))) # 1-tuple + @inferred get_positions(mm) # sizes are inferred here + testm = (m) -> begin + @test length(mm) == length(cv) * n_col + cm = mm(1:length(mm)) + #cm[:c1,:] + @test cm[:c1, 2] == 6:7 + end + testm(mm) + # + n_z = 3 + mm = @inferred stack_ca_int(mv, Val((n_col, n_z))) + testm = (m) -> begin + @test mm isa AbstractComponentArrayInterpreter + @test length(mm) == length(cv) * n_col * n_z + cm = mm(1:length(mm)) + @test cm[:c1, 2, 2] == 26:27 + end + testm(mm) + # + n_row = 3 + mm = @inferred stack_ca_int(Val((n_row,)), mv) + testm = (m) -> begin + @test mm isa AbstractComponentArrayInterpreter + @test length(mm) == n_row * length(mv) + cm = mm(1:length(mm)) + @test cm[2, :c1] == [2, 5] + end + testm(mm) + # + f_n_within = (n) -> begin + mm = @inferred stack_ca_int(Val((n,)), 
mv) + end + @test_broken @inferred f_n_within(3) # inferred is only + f_outer = () -> begin + f_n_within_cols = (n) -> begin + mm = @inferred stack_ca_int(mv, Val((n,))) + mm = get_concrete(ComponentArrayInterpreter(mv, (3,))) # same effects + end + # @inferred f_n_within_cols(3) # inferred is only Any + res = f_n_within_cols(3) # inferred is only + pos = @inferred get_positions(res) # but within this context size is known + @inferred res(pos) + end + #pos_outer = @inferred f_outer() # but inferred return type is Any + pos_outer = f_outer() +end; + + @testset "empty ComponentVector" begin x = CA.ComponentVector{Float32}() int1 = ComponentArrayInterpreter(x) @@ -95,5 +187,83 @@ end; @test int3 == int1 end; +@testset "compose_interpreters" begin + int1 = get_concrete(ComponentArrayInterpreter(CA.ComponentVector(x1=1:3, x2=4:5))) + int2 = get_concrete(ComponentArrayInterpreter(CA.ComponentVector(y1=1:2, y2=3:5))) + intm2 = get_concrete(ComponentArrayInterpreter(int2, (3,))) + #intc = ComponentArrayInterpreter((a=int1, b=int2)) + ints = (a=int1, b=intm2) + intc = @inferred compose_interpreters(;ints...) + # @usingany Cthulhu + # @descend_code_warntype CP.StaticComponentArrayInterpreter(a=int1, b=int2) + # @descend_code_warntype CP.compose_axes(map(x -> CA.getaxes(x), ints)) + () -> begin + nt = (a=int1, b=int2) + nt isa NamedTuple{keys, <:NTuple{N, <:AbstractComponentArrayInterpreter}} where {keys, N} + nt isa NamedTuple{keys, <:NTuple{N}} where {keys, N} + nt isa NamedTuple{keys} where {keys} + end + # + v3 = CA.ComponentVector(a = get_positions(int1), b = get_positions(intm2)) + intc2 = ComponentArrayInterpreter(v3) + @test intc == intc2 + v3r = @inferred get_concrete(intc)(CA.getdata(v3)) + @test v3r == v3 + #@usingany BenchmarkTools + #@benchmark ComponentArrayInterpreter(a=int1, b=int2) # 6 allocations? + #@benchmark CP.StaticComponentArrayInterpreter(a=int1, b=int2) # still 5 allocations? + #@benchmark CP.compose_axes((a=int1, b=int2)) # still 5 allocations? 
+ #@usingany Cthulhu + # Cthulhu.@descend_code_typed ComponentArrayInterpreter(a=int1, b=int2) + # @code_typed get_concrete(ComponentArrayInterpreter(a=int1, b=int2)) + if gdev isa MLDataDevices.AbstractGPUDevice + vd = gdev(CA.getdata(v3)) + f1 = (v) -> begin + #intc = @inferred compose_interpreters(a=int1, b=intm2) # fails on Zygote + intc = compose_interpreters(a=int1, b=intm2) + vc = intc(v) + sum(vc.a.x1)::eltype(vc) # eltype necessary + #sum(vc.a.x1) + end + @test @inferred f1(vd) == sum(v3.a.x1) + df1 = Zygote.gradient(v -> f1(v), vd)[1]; + @test df1 isa AbstractGPUArray + end + +end; + +@testset "type inference concrete Array interpreter" begin + cai0 = ComponentArrayInterpreter(x=(3,2)) + cai = get_concrete(cai0) + v = collect(1:length(cai)) + cv = cai(v) + + cv2 = @inferred CP.tmpf(v; cv) # cai by keyword argument + #cv2 = @inferred CP.tmpf(v; cv=nothing, cai = cai0) # not inferred + cv2 = CP.tmpf(v; cv=nothing, cai = cai0) # not inferred + cv2 = @inferred CP.tmpf1(v; cai = get_concrete(cai0)) # cai by keyword argument + #cv2 = @inferred CP.tmpf1(v; cai = cai0) # inside function does not infer + cv2 = CP.tmpf1(v; cai = cai0) # get_concrete inside function does not infer outside + cv2 = @inferred CP.tmpf2(v; cai=cai0) # only when specifying return type + # () -> begin + # #cv2 = @code_warntype CP.tmpf(cai0) # Any + # #cv2 = @code_warntype CP.tmpf(cai) # ok + # cv2 = @code_warntype CP.tmpf(v;cv, cai) # ok, keywords work + # cv2 = @code_warntype CP.tmpf(v;cv, cai=cai0) # Any + # cv2 = @code_warntype CP.tmpf(v; cv) # ok !! 
+ # cv2 = CP.tmpf(v; cv) + # typeof(cv2) + + # cv2 = CP.tmpf2(v; cai=cai) + # cv2 = @code_warntype CP.tmpf2(v; cai=cai) #ok + # cv2 = @code_warntype CP.tmpf2(v; cai=cai0) # + # cv2 = @code_warntype sum(CP.tmpf2(v; cai=cai0)) # + # cv2 = @code_warntype sum(CP.tmpf2(v; cai=cai0).x) # + # # @usingany Cthulhu + # # @descend_code_warntype CP.tmpf2(v; cai=cai0) + # # @code_warntype CP.tmpf2(v; cai=cai0) + # cv2 = CP.tmpf2(v; cai=cai0) # + # end +end diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index 099ff0f..1ab7bc3 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -19,32 +19,39 @@ using Suppressor cdev = cpu_device() -#scenario = (:default,) -#scenario = (:covarK2,) +#scenario = Val((:default,)) +#scenario = Val((:covarK2,)) +#scen = CP._val_value(scenario) - -construct_problem = (;scenario=(:default,)) -> begin +function construct_problem(; scenario::Val{scen}) where scen FT = Float32 θP = CA.ComponentVector{FT}(r0=0.3, K2=2.0) θM = CA.ComponentVector{FT}(r1=0.5, K1=0.2) - transP = elementwise(exp) - transM = Stacked(elementwise(identity), elementwise(exp)) + transP = Stacked((CP.Exp(),),(1:2,)) # elementwise(exp) + transM = Stacked(identity, CP.Exp()) # test different par transforms cor_ends = (P=1:length(θP), M=[length(θM)]) # assume r0 independent of K2 int_θdoubleMM = get_concrete(ComponentArrayInterpreter( flatten1(CA.ComponentVector(; θP, θM)))) - function f_doubleMM(θ::AbstractVector, x) + function f_doubleMM(θ::AbstractVector{ET}, x; intθ1) where ET # extract parameters not depending on order, i.e whether they are in θP or θM - θc = int_θdoubleMM(θ) - r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)] - y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) + local θc = intθ1(θ) + (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par + CA.getdata(θc[par])::ET + end + local y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) return (y) - end - function f_doubleMM_with_global(θP::AbstractVector, 
θMs::AbstractMatrix, xP) - #Main.@infiltrate_main - #first(eachcol(xP)) - pred_sites = applyf(f_doubleMM, θMs, θP, CA.ComponentVector{FT}(), eachcol(xP)) - pred_global = eltype(pred_sites)[] - return pred_global, pred_sites + end + f_doubleMM_sites = let intθ1 = int_θdoubleMM, f_doubleMM=f_doubleMM, + θFix = CA.ComponentVector{FT}() + function f_doubleMM_with_global_inner( + θP::AbstractVector{ET}, θMs::AbstractMatrix, xP + ) where ET + #first(eachcol(xP)) + local θMst = θMs' # map_f_each:site requires sites-last format + local pred_sites = map_f_each_site(f_doubleMM, θMst, θP, θFix, xP; intθ1) + local pred_global = eltype(pred_sites)[] + return pred_global, pred_sites + end end n_out = length(θM) rng = StableRNG(111) @@ -52,9 +59,9 @@ construct_problem = (;scenario=(:default,)) -> begin n_site, n_batch = get_hybridproblem_n_site_and_batch(CP.DoubleMM.DoubleMMCase(); scenario) # dependency on DeoubleMMCase -> take care of changes in covariates (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc - ) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase()) + ) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario) n_covar = size(xM,1) - n_input = (:covarK2 ∈ scenario) ? n_covar +1 : n_covar + n_input = (:covarK2 ∈ scen) ? 
n_covar +1 : n_covar g_chain = SimpleChain( static(n_input), # input dimension (optional) # dense layer with bias that maps to 8 outputs and applies `tanh` activation @@ -72,32 +79,57 @@ construct_problem = (;scenario=(:default,)) -> begin # MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites), batchsize=n_batch, partial=false) # end # end - train_dataloader = MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites), batchsize=n_batch, partial=false) + train_dataloader = MLUtils.DataLoader( + (xM, xP, y_o, y_unc, i_sites), batchsize=n_batch, partial=false) θall = vcat(θP, θM) - priors_dict = Dict{Symbol, Distribution}(keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95))) + priors_dict = Dict{Symbol, Distribution}( + keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95))) priors_dict[:r1] = fit(Normal, θall.r1, qp_uu(3 * θall.r1)) # not transformed to log-scale # scale (0,1) outputs MLmodel to normal distribution fitted to priors translated to ζ priorsM = [priors_dict[k] for k in keys(θM)] + lowers, uppers = get_quantile_transformed(priorsM, transM) app, ϕg0 = construct_ChainsApplicator(rng, g_chain) - g_chain_scaled = NormalScalingModelApplicator(app, priorsM, transM, FT) + g_chain_scaled = NormalScalingModelApplicator(app, lowers, uppers, FT) #g_chain_scaled = app ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT)) - pbm_covars = (:covarK2 ∈ scenario) ? (:K2,) : () + pbm_covars = (:covarK2 ∈ scen) ? 
(:K2,) : () HybridProblem(θP, θM, g_chain_scaled, ϕg0, ϕunc0, - f_doubleMM_with_global, f_doubleMM_with_global, priors_dict, py, + f_doubleMM_sites, f_doubleMM_sites, priors_dict, py, transM, transP, train_dataloader, n_covar, n_site, n_batch, cor_ends, pbm_covars) +end + +@testset "f_doubleMM from ProbSpec" begin + θ1 = CA.ComponentVector(r0=1.1, r1=2.1, K1=3.1, K2=4.1) + θ2 = reverse(θ1) + int_θdoubleMM = get_concrete(ComponentArrayInterpreter(θ2)) + xP1 = CA.ComponentVector(S1=[1,1,0.3,0.1], S2=[0.1,0.3,1,1,]) + n_site = 4 + xPint = ComponentArrayInterpreter((n_site,), ComponentArrayInterpreter(xP1)) + xP = xPint(repeat(xP1, outer=(1,4))) + function test_f_doubleMM(θ::AbstractVector{ET}, x; intθ1) where ET + # extract parameters not depending on order, i.e whether they are in θP or θM + θc = intθ1(θ) + (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par + CA.getdata(θc[par])::ET + end + r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) + end + y = @inferred test_f_doubleMM(CA.getdata(θ2), xP1; intθ1 = int_θdoubleMM) + # using ShareAdd; @usingany Cthulhu + # @descend_code_warntype test_f_doubleMM(CA.getdata(θ2), xP1) end test_without_flux = (scenario) -> begin + #scen = CP._val_value(scenario) gdev = @suppress gpu_device() prob = probc = construct_problem(;scenario); #@descend construct_problem(;scenario) - @testset "n_input and pbm_covars $(last(scenario))" begin + @testset "n_input and pbm_covars $(last(CP._val_value(scenario)))" begin g, ϕ_g = get_hybridproblem_MLapplicator(prob; scenario); - if :covarK2 ∈ scenario + if :covarK2 ∈ CP._val_value(scenario) @test g.app.m.inputdim == (static(6),) # 5 + 1 (ncovar + n_pbm) @test get_hybridproblem_pbmpar_covars(prob; scenario) == (:K2,) else @@ -106,7 +138,7 @@ test_without_flux = (scenario) -> begin end end - @testset "loss_gf $(last(scenario))" begin + @testset "loss_gf $(last(CP._val_value(scenario)))" begin #----------- fit g and θP to y_o rng = StableRNG(111) g, ϕg0 = get_hybridproblem_MLapplicator(prob; 
scenario) @@ -125,8 +157,12 @@ test_without_flux = (scenario) -> begin # Pass the site-data for the batches as separate vectors wrapped in a tuple y_global_o = Float64[] - loss_gf = get_loss_gf(g, transM, transP, f, y_global_o, intϕ; pbm_covars) - l1 = loss_gf(p0, first(train_loader)...) + loss_gf = get_loss_gf(g, transM, transP, f, y_global_o, intϕ; + pbm_covars, n_site_batch = n_batch) + (_xM, _xP, _y_o, _y_unc, _i_sites) = first(train_loader) + l1 = @inferred ( + # @descend_code_warntype ( + loss_gf(p0, _xM, _xP, _y_o, _y_unc, _i_sites)) tld = first(train_loader) gr = Zygote.gradient(p -> loss_gf(p, tld...)[1], CA.getdata(p0)) @test gr[1] isa Vector @@ -146,8 +182,8 @@ test_without_flux = (scenario) -> begin end end -test_without_flux((:default,)) -test_without_flux((:covarK2,)) +test_without_flux(Val((:default,))) +test_without_flux(Val((:covarK2,))) import CUDA, cuDNN using GPUArraysCore @@ -159,7 +195,7 @@ gdev = gpu_device() test_with_flux = (scenario) -> begin prob = probc = construct_problem(;scenario); - @testset "HybridPointSolver $(last(scenario))" begin + @testset "HybridPointSolver $(last(CP._val_value(scenario)))" begin rng = StableRNG(111) solver = HybridPointSolver(; alg=Adam(0.02)) (; ϕ, resopt, probo) = solve(prob, solver; scenario, rng, @@ -169,6 +205,7 @@ test_with_flux = (scenario) -> begin maxiters=200, gdev = identity, #gpu_handler = NullGPUDataHandler + is_inferred = Val(true), ) (; θP) = get_hybridproblem_par_templates(prob; scenario) θPo = (() -> begin @@ -179,7 +216,7 @@ test_with_flux = (scenario) -> begin @test ϕ.ϕP.K2 < 1.5 * log(θP.K2) end; - @testset "HybridPosteriorSolver $(last(scenario))" begin + @testset "HybridPosteriorSolver $(last(CP._val_value(scenario)))" begin rng = StableRNG(111) solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) (; ϕ, θP, resopt) = solve(prob, solver; scenario, rng, @@ -187,7 +224,8 @@ test_with_flux = (scenario) -> begin #maxiters = 20 # too small so that it yields error maxiters=37, θmean_quant = 
0.01, # test constraining mean to initial prediction - gdev = identity + gdev = identity, + is_inferred = Val(true), ) θPt = get_hybridproblem_par_templates(prob; scenario).θP @test θP.r0 < 1.5 * θPt.r0 @@ -196,9 +234,11 @@ test_with_flux = (scenario) -> begin prob.θP end; + + if gdev isa MLDataDevices.AbstractGPUDevice - @testset "HybridPosteriorSolver gpu $(last(scenario))" begin - scenf = (scenario..., :use_Flux, :use_gpu, :omit_r0) + @testset "HybridPosteriorSolver gpu $(last(CP._val_value(scenario)))" begin + scenf = Val((CP._val_value(scenario)..., :use_Flux, :use_gpu, :omit_r0)) rng = StableRNG(111) # here using DoubleMMCase() directly rather than construct_problem #(;transP, transM) = get_hybridproblem_transforms(DoubleMM.DoubleMMCase(); scenario = scenf) @@ -210,6 +250,7 @@ test_with_flux = (scenario) -> begin maxiters = 37, # smallest value by trial and error #maxiters = 20 # too small so that it yields error θmean_quant = 0.01, # test constraining mean to initial prediction + is_inferred = Val(true), ); @test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector #@test cdev(ϕ.unc.ρsM)[1] > 0 # too few iterations in test -> may fail @@ -217,9 +258,17 @@ test_with_flux = (scenario) -> begin solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) (; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf, maxiters = 37, + is_inferred = Val(true), ); @test cdev(ϕ.unc.ρsM)[1] > 0 @test probo.ϕunc == cdev(ϕ.unc) + n_sample_pred = 22 + (; y, θsP, θsMs) = predict_hvi( + rng, probo; scenario = scenf, n_sample_pred, is_inferred=Val(true)); + (_xM, _xP, _y_o, _y_unc, _i_sites) = get_hybridproblem_train_dataloader(prob; scenario).data + @test size(y) == (size(_y_o)..., n_sample_pred) + @test size(θsP) == (size(probo.θP,1), n_sample_pred) + test_correlation = () -> begin n_epoch = 20 # requires (; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf, @@ -229,20 +278,30 @@ test_with_flux = (scenario) -> begin @test cdev(ϕ.unc.ρsM)[1] > 0 @test probo.ϕunc == 
cdev(ϕ.unc) # predict using problem and its associated dataloader - (; θ, y, entropy_ζ) = predict_gf(rng, probo; scenario = scenf, n_sample_pred = 200); - mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ[:, 1])[1]) - residθ = θ .- mean_θ - cr = cor(CA.getdata(residθ)); + n_sample_pred = 201 + (; y, θsP, θsMs) = predict_hvi(rng, probo; scenario = scenf, n_sample_pred); + # to inspect correlations among θP and θMs construct ComponentVector + hpints = HybridProblemInterpreters(prob; scenario) + int_mPMs = stack_ca_int(Val((n_sample_pred,)), get_int_PMst_site(hpints)) + θs = int_mPMs(CP.flatten_hybrid_pars(θsP, θsMs)) + mean_θ = CA.ComponentVector(vec(mean(CA.getdata(θs), dims=1)), last(CA.getaxes(θs))) + mean_θ.Ms + sd_θ = CA.ComponentVector(vec(std(CA.getdata(θs), dims=1)), last(CA.getaxes(θs))) + sd_θ.Ms + pos = get_positions(ComponentArrayInterpreter(mean_θ)) + residθs = θs .- mean_θ + + cr = cor(CA.getdata(residθs')) + pos_P = get_positions(ComponentArrayInterpreter(θs[:P,1])) i_sites = [1,2,3] - tmp = CA.ComponentArray(collect(axes(θ[:,1],1)), CA.getaxes(θ[:,1])); #ax = map(x -> axes(x,1), get_hybridproblem_par_templates(probo; scenario = scenf)) - is = vcat(tmp.P, vec(tmp.Ms[:,i_sites])) + is = vcat(pos.P, vec(pos.Ms[i_sites,:])) cr[is,is] end end; - @testset "HybridPosteriorSolver also f on gpu $(last(scenario))" begin - scenf = (scenario..., :use_Flux, :use_gpu, :omit_r0, :f_on_gpu) + @testset "HybridPosteriorSolver also f on gpu $(last(CP._val_value(scenario)))" begin + scenf = Val((CP._val_value(scenario)..., :use_Flux, :use_gpu, :omit_r0, :f_on_gpu)) rng = StableRNG(111) probg = HybridProblem(DoubleMM.DoubleMMCase(); scenario = scenf); #prob = CP.update(probg, transM = identity, transP = identity); @@ -253,13 +312,17 @@ test_with_flux = (scenario) -> begin maxiters = 37, # smallest value by trial and error #maxiters = 20 # too small so that it yields error #θmean_quant = 0.01, # TODO make possible on gpu - cdev = identity # do 
not move ζ to cpu # TODO infer in solve from scenario + cdev = identity, # do not move ζ to cpu # TODO infer in solve from scenario + is_inferred = Val(true), ); @test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector + n_sample_pred = 11 + (; y, θsP, θsMs) = predict_hvi( + rng, probo; scenario = scenf, n_sample_pred,is_inferred = Val(true)); # @test cdev(ϕ.unc.ρsM)[1] > 0 # too few iterations end; end # if gdev isa MLDataDevices.AbstractGPUDevice end # test_with flux -test_with_flux((:default,)) -test_with_flux((:covarK2,)) +test_with_flux(Val((:default,))) +test_with_flux(Val((:covarK2,))) diff --git a/test/test_bijectors_utils.jl b/test/test_bijectors_utils.jl index 351ee86..01a62f7 100644 --- a/test/test_bijectors_utils.jl +++ b/test/test_bijectors_utils.jl @@ -9,6 +9,8 @@ import CUDA, cuDNN using Zygote + + x = [0.1, 0.2, 0.3, 0.4] gdev = gpu_device() cdev = cpu_device() @@ -30,21 +32,21 @@ dy = Zygote.gradient(x -> trans(x,b2), x) @testset "elementwise exp" begin - ys = trans(x,b2s) + ys = @inferred trans(x,b2s) @test ys == y Zygote.gradient(x -> trans(x,b2s), x) end; @testset "Exp" begin - y1 = b3(x) - y2 = b3s(x) + y1 = @inferred b3(x) + y2 = @inferred b3s(x) @test all(inverse(b3)(y2) .≈ x) @test all(inverse(b3s)(y2) .≈ x) - ye = trans(x, b3) + ye = @inferred trans(x, b3) dye = Zygote.gradient(x -> trans(x,b3), x) @test ye == y @test dye == dy - ys = trans(x,b3s) + ys = @inferred trans(x,b3s) dys = Zygote.gradient(x -> trans(x,b2s), x) @test dys == dy end; @@ -53,21 +55,94 @@ end; if gdev isa MLDataDevices.AbstractGPUDevice xd = gdev(x) @testset "elementwise exp gpu" begin - ys = trans(xd,b2) + ys = @inferred trans(xd,b2) @test ys ≈ y @test_broken Zygote.gradient(x -> trans(x,b2), xd) @test_broken Zygote.gradient(x -> trans(x,b2s), xd) end; @testset "Exp" begin - ye = trans(xd, b3) + ye = @inferred trans(xd, b3) dye = Zygote.gradient(x -> trans(x,b3), xd) @test ye ≈ y @test all(cdev(dye) .≈ dy) - ys = trans(xd,b3s) + ys = @inferred trans(xd,b3s) dys = 
Zygote.gradient(x -> trans(x,b3s), xd) @test ys ≈ y @test all(cdev(dys) .≈ dy) end; end + +@testset "extend_stacked_nrow" begin + nrow = 50 # faster on CPU by factor of 20 + #nrow = 20000 # faster on GPU + X = reduce(hcat, ([x + y for x in 0:nrow] for y in 0:10:30)) + b1 = @inferred CP.Exp() + b2 = identity + b = @inferred Stacked((b1,b2), (1:1,2:size(X,2))) + bs = @inferred extend_stacked_nrow(b, size(X,1)) + Xt = @inferred reshape(bs(vec(X)), size(X)) + @test Xt[:,1] == b1(X[:,1]) + @test Xt[:,2] == b2(X[:,2]) + if gdev isa MLDataDevices.AbstractGPUDevice + Xd = gdev(X) + Xtd = @inferred reshape(bs(vec(Xd)), size(Xd)) + #Xtd2, logjac = with_logabsdet_jacobian(bs, Xd) + #@test Xtd2 == Xtd + # test transpose in gradient function + dys = Zygote.gradient(x -> sum(bs(vec(x'))), Xd')[1] + # () -> begin + # #@usingany BenchmarkTools + # @benchmark reshape(bs(vec(Xd)), size(Xd)) # macro not definedmetho + # vecXd = vec(Xd) + # @benchmark bs(vecXd) + # vecX = vec(X) + # @benchmark bs(vecX) + # Xdtrans = Xd' + # Xtrans = X' + # @benchmark Zygote.gradient(x -> sum(bs(vec(x'))), Xdtrans)[1] + # @benchmark Zygote.gradient(x -> sum(bs(vec(x'))), Xtrans)[1] + # end + end +end + +@testset "StackedArray" begin + nrow = 5 # faster on CPU by factor of 20 + #nrow = 20000 # faster on GPU + X = reduce(hcat, ([x + y for x in 0:nrow] for y in 0.0:10:30)) + b1 = @inferred CP.Exp() + b2 = identity + b = @inferred Stacked((b1,b2), (1:1,2:size(X,2))) + bs = @inferred StackedArray(b, size(X,1)) + Xt = @inferred bs(X) + @test Xt[:,1] == b1(X[:,1]) + @test Xt[:,2] == b2(X[:,2]) + X2 = @inferred inverse(bs)(Xt) + @test X2 == X + # test with Exp only + be1 = Stacked((CP.Exp(),),(1:size(X,2),)) + bse = StackedArray(be1, size(X,1)) + Xt = @inferred bse(X) # works also for adjoint + Xt2 = @inferred bse(copy(X')') # works also for adjoint + @test Xt2 == Xt + @inferred bse(X) + if gdev isa MLDataDevices.AbstractGPUDevice + Xd = gdev(X) + bse(Xd) + Xtd = @inferred bs(Xd) + Xtd2 = @inferred 
bs(copy(Xd')') # works also for adjoint + Xtd2 = @inferred bse(copy(Xd')') # needs copy workaround + #bse.stacked(vec(Xd')) # TODO write issue + tmpf = (X, bs) -> begin + Xt, logjac = with_logabsdet_jacobian(bs, X) + sum(Xt) .+ logjac + end + tmpf(Xd, bs) + # test transpose in gradient function + dys = Zygote.gradient(X -> tmpf(X', bs), Xd')[1] + @test all(dys[2:end,:] .== 1.0) + end +end + + diff --git a/test/test_cholesky_structure.jl b/test/test_cholesky_structure.jl index 91df685..882c0d6 100644 --- a/test/test_cholesky_structure.jl +++ b/test/test_cholesky_structure.jl @@ -8,7 +8,9 @@ using ComponentArrays: ComponentArrays as CA using GPUArraysCore: GPUArraysCore #using Flux import CUDA, cuDNN -using MLDataDevices +using MLDataDevices, Suppressor +using BlockDiagonals +using LinearAlgebra A = [1.0 2.0 3.0 2.0 1.0 4.0 @@ -23,7 +25,8 @@ C = [1.0 2.0 3.2 LC = cholesky(C).L Z = zeros(3, 3) -ggdev = gpu_device() +ggdev = Suppressor.@suppress gpu_device() +cdev = cpu_device() @testset "cholesky of blockdiagonal" begin @@ -52,6 +55,7 @@ end; @testset "invsumn" begin ns_orig = [1, 2, 3, 6] s = map(n -> sum(1:n), ns_orig) + @inferred CP.invsumn(s[1]) ns = CP.invsumn.(s) @test ns == ns_orig @test eltype(ns) == Int @@ -60,13 +64,13 @@ end; end; @testset "get_cor_count" begin - @test get_cor_count(Int[]) == 0 # case of no physical parameters - @test get_cor_count([1]) == 0 + @test @inferred get_cor_count(Int[]) == 0 # case of no physical parameters + @test @inferred get_cor_count([1]) == 0 @test get_cor_count([2]) == 1 @test get_cor_count([3]) == 3 @test get_cor_count([4]) == 6 - @test get_cor_count(4) == 6 - @test get_cor_count([1,4]) == 0 + 3 + @test @inferred get_cor_count(4) == 6 + @test @inferred get_cor_count([1,4]) == 0 + 3 @test get_cor_count([2,4]) == 1 + 1 @test get_cor_count([3,4]) == 3 + 0 @test get_cor_count([2,5]) == 1 + 3 @@ -74,11 +78,13 @@ end; @testset "vec2utri" begin v_orig = 1.0:6.0 - Uv = CP.vec2utri(v_orig) + Uv = @inferred CP.vec2utri(v_orig) + 
#@usingany Cthulhu + #@descend_code_warntype CP.vec2utri(v_orig) @test Uv isa UpperTriangular Zygote.gradient(v -> sum(CP.vec2utri(v)), v_orig)[1] # works nice # - v2 = CP.utri2vec(Uv) + v2 = @inferred CP.utri2vec(Uv) @test v2 == v_orig Zygote.gradient(Uv -> sum(CP.utri2vec(Uv)), Uv)[1] # works nice end; @@ -88,25 +94,25 @@ end; vcpu = collect(v_orig) n = CP.invsumn(length(v)) + 1 T = eltype(v) - U1v = CP.vec2uutri(v_orig) + U1v = @inferred CP.vec2uutri(v_orig) @test U1v isa UnitUpperTriangular @test size(U1v, 1) == 4 gr = Zygote.gradient(v -> sum(abs2.(CP.vec2uutri(v))), vcpu)[1] # works nice # test providing keyword argument gr = Zygote.gradient(v -> sum(abs2.(CP.vec2uutri(v; n = 4))), vcpu)[1] # works nice # - v2 = CP.uutri2vec(U1v) + v2 = @inferred CP.uutri2vec(U1v) @test v2 == v_orig gr = Zygote.gradient(U1v -> sum(CP.uutri2vec(U1v) .* (1.0:6.0)), U1v)[1] # works nice end; @testset "utri2vec_pos" begin - @test CP.utri2vec_pos(1, 1) == 1 + @test @inferred CP.utri2vec_pos(1, 1) == 1 @test CP.utri2vec_pos(1, 2) == 2 @test CP.utri2vec_pos(2, 2) == 3 @test CP.utri2vec_pos(1, 3) == 4 @test CP.utri2vec_pos(1, 4) == 7 - @test CP.utri2vec_pos(5, 5) == 15 + @test @inferred CP.utri2vec_pos(5, 5) == 15 typeof(CP.utri2vec_pos(5, 5)) == Int typeof(CP.utri2vec_pos(Int32(5), Int32(5))) == Int32 @test_throws AssertionError CP.utri2vec_pos(2, 1) @@ -114,65 +120,78 @@ end @testset "vec2uutri gpu" begin if ggdev isa MLDataDevices.AbstractGPUDevice # only run the test, if CUDA is working (not on Github ci) - v_orig = 1.0f0:6.0f0 + v_orig = 1.0f0:1.0f0:6.0f0 + #v = ggdev(v_orig) v = ggdev(collect(v_orig)) - U1v = CP.vec2uutri(v) + U1v = @inferred CP.vec2uutri(v) @test !(U1v isa UnitUpperTriangular) # on CUDA work with normal matrix @test U1v isa GPUArraysCore.AbstractGPUMatrix @test size(U1v, 1) == 4 - @test Array(U1v) == CP.vec2uutri(v_orig) + @test cdev(U1v) == CP.vec2uutri(v_orig) gr = Zygote.gradient(v -> sum(abs2.(CP.vec2uutri(v))), v)[1] # works nice @test gr isa 
GPUArraysCore.AbstractGPUArray - @test Array(gr) == (1:6) .* 2.0 + @test cdev(gr) == (1:6) .* 2.0 # - v2 = CP.uutri2vec(U1v) + v2 = @inferred CP.uutri2vec(U1v) @test v2 isa GPUArraysCore.AbstractGPUVector @test eltype(v2) == eltype(U1v) - @test Array(v2) == v_orig + @test cdev(v2) == v_orig gr = Zygote.gradient(U1v -> sum(CP.uutri2vec(U1v .* 2)), U1v)[1] # works nice @test gr isa GPUArraysCore.AbstractGPUArray @test all(diag(gr) .== 0) - @test Array(CP.uutri2vec(gr)) == fill(2.0f0, length(v_orig)) + @test cdev(CP.uutri2vec(gr)) == fill(2.0f0, length(v_orig)) end end; @testset "transformU_cholesky1 gpu" begin - v_orig = 1.0f0:6.0f0 + v_orig = 1.0f0:1.0f0:6.0f0 vcpu = collect(v_orig) - ch = CP.transformU_cholesky1(vcpu) + ch = @inferred CP.transformU_cholesky1(vcpu) gr1 = Zygote.gradient(v -> sum(CP.transformU_cholesky1(v)), vcpu)[1] # works nice @test all(diag(ch' * ch) .≈ 1) if ggdev isa MLDataDevices.AbstractGPUDevice # only run the test, if CUDA is working (not on Github ci) v = ggdev(collect(v_orig)) - U1v = CP.transformU_cholesky1(v) + U1v = @inferred CP.transformU_cholesky1(v) @test !(U1v isa UnitUpperTriangular) # on CUDA work with normal matrix @test U1v isa GPUArraysCore.AbstractGPUMatrix - @test Array(U1v) ≈ ch + @test cdev(U1v) ≈ ch gr2 = Zygote.gradient(v -> sum(CP.transformU_cholesky1(v)), v)[1] # works nice - @test Array(gr2) == gr1 + @test cdev(gr2) == gr1 end end; @testset "transformU_block_cholesky1 gpu" begin - vc = CA.ComponentVector(b1 = 1.0f0:3.0f0) - vc = CA.ComponentVector(b1 = 1.0f0:3.0f0, b2 = [5.0f0]) + vc = CA.ComponentVector(b1 = 1.0f0:1.0f0:3.0f0) + vc = CA.ComponentVector(b1 = 1.0f0:1.0f0:3.0f0, b2 = [5.0f0]) v = CA.getdata(vc) - cor_ends = get_ca_ends(vc) + cor_ends = @inferred get_ca_ends(vc) #ns=(CP.invsumn(length(v[k])) + 1 for k in keys(v)) #collect(ns) ρ = collect(1f0:get_cor_count(cor_ends)) - U = CP.transformU_block_cholesky1(ρ) - U = CP.transformU_block_cholesky1(v, cor_ends) + tmp = @inferred CP.transformU_cholesky1(ρ) + 
_cor_counts = @inferred CP.get_cor_counts(cor_ends) # number of correlation parameters + # depr cholesky is either a BlockDiagonal of UpperTriangular of an plain UpperTriangular + # always a BlockDiagonal + # T = eltype(v) + # UT = Union{BlockDiagonals.BlockDiagonal{T, UpperTriangular{T}}, UpperTriangular{T}} + U = @inferred CP.transformU_block_cholesky1(ρ) + #U = @inferred UT CP.transformU_block_cholesky1(ρ) + #@descend_code_warntype CP.transformU_block_cholesky1(ρ) + U = @inferred CP.transformU_block_cholesky1(v, cor_ends) + #U = @inferred UT CP.transformU_block_cholesky1(v, cor_ends) + #@descend_code_warntype CP.transformU_block_cholesky1(v, cor_ends) @test diag(U' * U) ≈ ones(4) @test U[1:3, 4:4] ≈ zeros(3, 1) gr1 = Zygote.gradient(v -> sum(CP.transformU_block_cholesky1(v, cor_ends)), v)[1]; # works nice # degenerate case of no correlations vc0 = CA.ComponentVector{Float32}() - cor_ends0 = get_ca_ends(vc0) + cor_ends0 = @inferred get_ca_ends(vc0) + #@descend_code_warntype get_ca_ends(vc0) ρ0 = collect(1f0:get_cor_count(cor_ends0)) #ns=(CP.invsumn(length(v[k])) + 1 for k in keys(v)) #collect(ns) - U = CP.transformU_block_cholesky1(CA.getdata(ρ0), cor_ends0) + U = @inferred CP.transformU_block_cholesky1(CA.getdata(ρ0), cor_ends0) + #U = @inferred UT CP.transformU_block_cholesky1(CA.getdata(ρ0), cor_ends0) @test diag(U) == [1f0] gr1 = Zygote.gradient(v -> sum(CP.transformU_block_cholesky1(ρ0, cor_ends0)), v)[1]; # works nice @@ -180,19 +199,21 @@ end; vc = v_orig = CA.ComponentVector(b1 = ggdev(1.0f0:3.0f0), b2 = ggdev([5.0f0])) vc = v_orig = ggdev(CA.ComponentVector(b1 = 1.0f0:3.0f0, b2 = [5.0f0])) v = CA.getdata(vc) - cor_ends = get_ca_ends(vc) + cor_ends = @inferred get_ca_ends(vc) ρ = ggdev(collect(1f0:get_cor_count(cor_ends))) - U = CP.transformU_block_cholesky1(ρ, cor_ends) + U = @inferred CP.transformU_block_cholesky1(ρ, cor_ends) + #U = @inferred UT CP.transformU_block_cholesky1(ρ, cor_ends) @test U isa GPUArraysCore.AbstractGPUArray - @test diag(Array(U' 
* U)) ≈ ones(4) - @test Array(U[1:3, 4:4]) ≈ zeros(3, 1) + @test diag(cdev(U' * U)) ≈ ones(4) + @test cdev(U[1:3, 4:4]) ≈ zeros(3, 1) gr1 = Zygote.gradient(v -> sum(CP.transformU_block_cholesky1(v, cor_ends)), v)[1] # works nice # cor_ends0 = Int64[] ρ0 = ggdev(collect(1f0:get_cor_count(cor_ends0))) - U = CP.transformU_block_cholesky1(ρ0, cor_ends0) + U = @inferred CP.transformU_block_cholesky1(ρ0, cor_ends0) + #U = @inferred UT CP.transformU_block_cholesky1(ρ0, cor_ends0) @test U isa GPUArraysCore.AbstractGPUArray - @test Array(diag(U)) == [1f0] + @test cdev(diag(U)) == [1f0] end end; @@ -277,6 +298,7 @@ end S = _X * _X' # know that this is Hermitian n_x = 200 xs = rand(3, n_x) + tmp = @inferred cholesky(S) SU = cholesky(S).U σ_o = 0.05 ys_true = ysS = xs' * SU @@ -284,16 +306,18 @@ end Dσ = Diagonal(sqrt.(diag(S))) # assume given n_U = size(S, 1) - - fcost = fcostS = (Us1vec) -> begin - U = CP.transformU_cholesky1(Us1vec; n = n_U) - y_pred = (xs' * U) * Dσ - sum(abs2, ys .- y_pred) + fcost = fcostS = let n_U = n_U, xs = xs, Dσ=Dσ, ys=ys + fcost_inner = (Us1vec) -> begin + U = CP.transformU_cholesky1(Us1vec; n = n_U) + y_pred = (xs' * U) * Dσ + sum(abs2, ys .- y_pred) + end end # cannot infer true U_scaled any more Unscaled0 = S ./ diag(S) - Us1vec0 = CP.uutri2vec(Unscaled0) - fcost(Us1vec0) + Us1vec0 = @inferred CP.uutri2vec(Unscaled0) + @inferred fcost(Us1vec0) + #@descend_code_warntype fcost(Us1vec0) #fcostS(resCT.u) # cost of u optimized by Covar should yield small result if same x optf = Optimization.OptimizationFunction((x, p) -> fcost(x), Optimization.AutoZygote()) diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index f23e868..4988866 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -1,6 +1,6 @@ using Test using HybridVariationalInference -using HybridVariationalInference: HybridVariationalInference as HVI +using HybridVariationalInference: HybridVariationalInference as CP using StableRNGs using Random using Statistics @@ -22,9 
+22,9 @@ gdev = gpu_device() cdev = cpu_device() prob = DoubleMM.DoubleMMCase() -scenario = (:default,) +scenario = Val((:default,)) #using Flux -#scenario = (:use_Flux,) +#scenario = Val((:use_Flux,)) par_templates = get_hybridproblem_par_templates(prob; scenario) @@ -60,33 +60,38 @@ end is = repeat((1:length(θP_true))', n_site) θvec = CA.ComponentVector(P = θP_true, Ms = θMs_true) #xPM = map(xP1s -> repeat(xP1s', n_site), xP[1]) - xPM = (S1 = CA.getdata(xP[:S1,:])', S2 = CA.getdata(xP[:S2,:])') + xPM = (S1 = CA.getdata(xP[:S1, :])', S2 = CA.getdata(xP[:S2, :])') #θ = hcat(θP_true[is], θMs_true') intθ1 = get_concrete(ComponentArrayInterpreter(vcat(θP_true, θMs_true[:, 1]))) #θpos = get_positions(intθ1) intθ = get_concrete(ComponentArrayInterpreter((n_site,), intθ1)) - fy = (θvec, xPM) -> begin - θ = hcat(CA.getdata(θvec.P[is]), CA.getdata(θvec.Ms')) - y = HVI.DoubleMM.f_doubleMM(θ, xPM, intθ) - #y = HVI.DoubleMM.f_doubleMM(θ, xPM, θpos) + # TODO replace is by ComponentArrayInterpreter + fy = let is = is, intθ = intθ + (θvec, xPM) -> begin + θ = hcat(CA.getdata(θvec.P[is]), CA.getdata(θvec.Ms')) + y = CP.DoubleMM.f_doubleMM(θ, xPM; intθ) + #y = @inferred CP.DoubleMM.f_doubleMM(θ, xPM, intθ) + # @descend_code_warntype CP.DoubleMM.f_doubleMM(θ, xPM, intθ) + #y = CP.DoubleMM.f_doubleMM(θ, xPM, θpos) + end end - y = fy(θvec, xPM) - y_exp = applyf(HVI.DoubleMM.f_doubleMM, θMs_true, θP_true, - Vector{eltype(θP_true)}(undef, 0), eachcol(xP), intθ1) + y = @inferred fy(θvec, xPM) + y_exp = map_f_each_site(CP.DoubleMM.f_doubleMM, θMs_true, θP_true, + Vector{eltype(θP_true)}(undef, 0), xP; intθ1) @test y == y_exp' ygrad = Zygote.gradient(θv -> sum(fy(θv, xPM)), θvec)[1] if gdev isa MLDataDevices.AbstractGPUDevice # θg = gdev(θ) # xPMg = gdev(xPM) - # yg = HVI.DoubleMM.f_doubleMM(θg, xPMg, intθ); + # yg = CP.DoubleMM.f_doubleMM(θg, xPMg, intθ); θvecg = gdev(θvec); # errors without ";" xPMg = gdev(xPM) - yg = fy(θvecg, xPMg) + yg = @inferred fy(θvecg, xPMg) @test cdev(yg) == 
y_exp' - ygradg = Zygote.gradient(θv -> sum(fy(θv, xPMg)), θvecg)[1] + ygradg = Zygote.gradient(θv -> sum(fy(θv, xPMg)), θvecg)[1] @test ygradg isa CA.ComponentArray @test CA.getdata(ygradg) isa GPUArraysCore.AbstractGPUArray - ygradgc = HVI.apply_preserve_axes(cdev, ygradg) # can print the cpu version + ygradgc = CP.apply_preserve_axes(cdev, ygradg) # can print the cpu version # ygradgc.P .- ygrad.P # ygradgc.Ms end @@ -95,33 +100,38 @@ end @testset "neg_logden_obs Matrix" begin is = repeat(axes(θP_true, 1)', n_site) θvec = CA.ComponentVector(P = θP_true, Ms = θMs_true) - xPM = (S1 = CA.getdata(xP[:S1,:])', S2 = CA.getdata(xP[:S2,:])') + xPM = (S1 = CA.getdata(xP[:S1, :])', S2 = CA.getdata(xP[:S2, :])') #θ = hcat(θP_true[is], θMs_true') intθ1 = get_concrete(ComponentArrayInterpreter(vcat(θP_true, θMs_true[:, 1]))) #θpos = get_positions(intθ1) intθ = get_concrete(ComponentArrayInterpreter((n_site,), intθ1)) - fcost = (θvec, xPM, y_o, y_unc) -> begin - θ = hcat(CA.getdata(θvec.P[is]), CA.getdata(θvec.Ms')) - y = HVI.DoubleMM.f_doubleMM(θ, xPM, intθ) - #y = HVI.DoubleMM.f_doubleMM(θ, xPM, θpos) - fneglogden(y_o, y', y_unc) + fcost = let is = is, intθ = intθ, fneglogden=fneglogden + (θvec, xPM, y_o, y_unc) -> begin + θ = hcat(CA.getdata(θvec.P[is]), CA.getdata(θvec.Ms')) + y = CP.DoubleMM.f_doubleMM(θ, xPM; intθ) + #y = CP.DoubleMM.f_doubleMM(θ, xPM, θpos) + res = fneglogden(y_o, y', y_unc) + res + end end - cost = fcost(θvec, xPM, y_o, y_unc) + #fcost = CP.tmp_fcost(is, intθ, fneglogden) + cost = @inferred fcost(θvec, xPM, y_o, y_unc) + # @descend_code_warntype fcost(θvec, xPM, y_o, y_unc) ygrad = Zygote.gradient(θv -> fcost(θv, xPM, y_o, y_unc), θvec)[1] if gdev isa MLDataDevices.AbstractGPUDevice # θg = gdev(θ) # xPMg = gdev(xPM) - # yg = HVI.DoubleMM.f_doubleMM(θg, xPMg, intθ); - θvecg = gdev(θvec); + # yg = CP.DoubleMM.f_doubleMM(θg, xPMg, intθ); + θvecg = gdev(θvec) xPMg = gdev(xPM) - y_og = gdev(y_o); + y_og = gdev(y_o) y_uncg = gdev(y_unc) costg = fcost(θvecg, 
xPMg, y_og, y_uncg) @test costg ≈ cost ygradg = Zygote.gradient(θv -> fcost(θv, xPMg, y_og, y_uncg), θvecg)[1]; # errors without ";" @test ygradg isa CA.ComponentArray @test CA.getdata(ygradg) isa GPUArraysCore.AbstractGPUArray - ygradgc = HVI.apply_preserve_axes(cdev, ygradg) # can print the cpu version + ygradgc = CP.apply_preserve_axes(cdev, ygradg) # can print the cpu version # ygradgc.P .- ygrad.P # ygradgc.Ms end @@ -130,24 +140,33 @@ end @testset "loss_g" begin g, ϕg0 = get_hybridproblem_MLapplicator(rng, prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + transMs = StackedArray(transM, n_batch) + xM_batch = xM[:, 1:n_batch] + θMs_true_tb = θMs_true[:, 1:20]' - function loss_g(ϕg, x, g, transM) - # @show first(x,5) - # @show first(ϕg,5) - ζMs = g(x, ϕg) # predict the parameters on unconstrained space - # @show first(ζMs,5) - θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column - loss = sum(abs2, θMs .- θMs_true) - return loss, θMs + loss_g = let θMs_true_tb = θMs_true_tb + function loss_g_inner(ϕg, x, g, transMs) + # @show first(x,5) + # @show first(ϕg,5) + ζMs = g(x, ϕg)' # predict the parameters on unconstrained space + # need to transpose, so that each parameter is a column -> for extend_stacked_nrow + θMs = transMs(ζMs) + # @show first(ζMs,5) + #θMs = reduce(hcat, map(transM, eachcol(ζMs_parfirst))) # transform each column + loss = sum(abs2, θMs .- θMs_true_tb) + return loss, θMs + end end - l = loss_g(ϕg0, xM, g, transM) + l = @inferred loss_g(ϕg0, xM_batch, g, transMs) @test isfinite(l[1]) - Zygote.gradient(ϕg -> loss_g(ϕg, xM, g, transM)[1], ϕg0) + Zygote.gradient(ϕg -> loss_g(ϕg, xM_batch, g, transMs)[1], ϕg0) # # actual optimization (do not need to test each time) () -> begin #histogram(ϕg0) - optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1], + optf = Optimization.OptimizationFunction( + (ϕg, p) -> 
loss_g(ϕg, xM_batch, g, transMs)[1], Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, ϕg0) #res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600); @@ -156,12 +175,12 @@ end ϕg_opt1 = res.u #histogram(ϕg_opt1) # all similar magnitude around zero #first(ϕg_opt1,5) - pred = loss_g(ϕg_opt1, xM, g, transM) + pred = loss_g(ϕg_opt1, xM_batch, g, transMs) θMs_pred = θMs_pred_1 = pred[2] - #scatterplot(vec(θMs_true), vec(θMs_pred)) + #scatterplot(vec(θMs_true_tb), vec(θMs_pred)) #@test cor(vec(θMs_true), vec(θMs_pred)) > 0.9 - @test cor(θMs_true[:, 1], θMs_pred[:, 1]) > 0.9 - @test cor(θMs_true[:, 2], θMs_pred[:, 2]) > 0.9 + @test cor(θMs_true_tb[:, 1], θMs_pred[:, 1]) > 0.9 + @test cor(θMs_true_tb[:, 2], θMs_pred[:, 2]) > 0.9 end end @@ -178,7 +197,7 @@ end intϕ = ComponentArrayInterpreter(CA.ComponentVector( ϕg = 1:length(ϕg0), ϕP = par_templates.θP)) p = p0 = vcat(ϕg0, - HVI.apply_preserve_axes(inverse(transP), par_templates.θP) .- + CP.apply_preserve_axes(inverse(transP), par_templates.θP) .- convert(eltype(ϕg0), 0.1)) # slightly disturb θP_true #p = p0 = vcat(ϕg_opt1, par_templates.θP); # almost true @@ -190,15 +209,18 @@ end pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) #loss_gf = get_loss_gf(g, transM, f, y_global_o, intϕ; gdev = identity) - loss_gf = get_loss_gf(g, transM, transP, f, y_global_o, intϕ; pbm_covars) - loss_gf2 = get_loss_gf(g, transM, transP, f2, y_global_o, intϕ; pbm_covars) - l1 = loss_gf(p0, first(train_loader)...)[1] + loss_gf = get_loss_gf(g, transM, transP, f, y_global_o, intϕ; + pbm_covars, n_site_batch = n_batch) + loss_gf2 = get_loss_gf(g, transM, transP, f2, y_global_o, intϕ; + pbm_covars, n_site_batch = n_site) + l1 = @inferred first(loss_gf(p0, first(train_loader)...)) (xM_batch, xP_batch, y_o_batch, y_unc_batch, i_sites_batch) = first(train_loader) + # @usingany Cthulhu + # @descend_code_warntype loss_gf(p0, xM_batch, xP_batch, y_o_batch, y_unc_batch, 
i_sites_batch) Zygote.gradient( - p0 -> loss_gf( - p0, xM_batch, xP_batch, y_o_batch, y_unc_batch, i_sites_batch)[1], CA.getdata(p0)) - - optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1], + p0 -> first(loss_gf( + p0, xM_batch, xP_batch, y_o_batch, y_unc_batch, i_sites_batch)), CA.getdata(p0)) + optf = Optimization.OptimizationFunction((ϕ, data) -> first(loss_gf(ϕ, data...)), Optimization.AutoZygote()) optprob = OptimizationProblem(optf, CA.getdata(p0), train_loader) @@ -206,29 +228,28 @@ end #optprob, Adam(0.02), callback = callback_loss(100), maxiters = 5000); optprob, Adam(0.02), maxiters = 1000) - l1, y_pred_global, y_pred, θMs_pred, θP_pred = loss_gf2(res.u, train_loader.data...) + l1, y_pred, θMs_pred, θP_pred = loss_gf2(res.u, train_loader.data...) #l1, y_pred_global, y_pred, θMs_pred = loss_gf(p0, xM, xP, y_o, y_unc); - θMs_pred = CA.ComponentArray(θMs_pred, CA.getaxes(θMs_true)) + θMs_pred = CA.ComponentArray(θMs_pred, CA.getaxes(θMs_true')) #TODO @test isapprox(par_templates.θP, intϕ(res.u).ϕP, rtol = 0.15) #@test cor(vec(θMs_true), vec(θMs_pred)) > 0.8 - @test cor(θMs_true[:, 1], θMs_pred[:, 1]) > 0.8 - @test cor(θMs_true[:, 2], θMs_pred[:, 2]) > 0.8 + @test cor(θMs_true'[:, 1], θMs_pred[:, 1]) > 0.8 + @test cor(θMs_true'[:, 2], θMs_pred[:, 2]) > 0.8 # started from low values -> increased but not too much above true values @test all(transP(intϕ(p0).ϕP) .< θP_pred .< (1.2 .* par_templates.θP)) () -> begin #@usingany UnicodePlots - scatterplot(vec(θMs_true), vec(θMs_pred)) - scatterplot(θMs_true[1, :], θMs_pred[1, :]) - scatterplot(θMs_true[2, :], θMs_pred[2, :]) - scatterplot(log.(vec(θMs_true)), log.(vec(θMs_pred))) + scatterplot(θMs_true'[:,1], θMs_pred[:,1]) + scatterplot(θMs_true'[:,2], θMs_pred[:,2]) + scatterplot(log.(vec(θMs_true')), log.(vec(θMs_pred))) scatterplot(vec(y_pred), vec(y_o)) hcat(par_templates.θP, intϕ(p0).ϕP, intϕ(res.u).ϕP, transP(intϕ(p0).ϕP), θP_pred) end end if gdev isa MLDataDevices.AbstractGPUDevice - 
scenario = (:use_Flux,) + scenario = Val((:use_Flux,:use_gpu)) g, ϕg0 = get_hybridproblem_MLapplicator(rng, prob; scenario) ϕg0_gpu = gdev(ϕg0) xM_gpu = gdev(xM) diff --git a/test/test_elbo.jl b/test/test_elbo.jl index 284c66d..4e70253 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -1,4 +1,5 @@ #using LinearAlgebra, BlockDiagonals +using LinearAlgebra using Test using HybridVariationalInference @@ -19,161 +20,413 @@ using Flux ggdev = gpu_device() - #CUDA.device!(4) rng = StableRNG(111) const prob = DoubleMM.DoubleMMCase() -scenario = (:default,) -#scenario = (:covarK2,) - +scenario = Val((:default,)) +#scenario = Val((:covarK2,)) test_scenario = (scenario) -> begin - FT = get_hybridproblem_float_type(prob; scenario) - par_templates = get_hybridproblem_par_templates(prob; scenario) - pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) + probc = HybridProblem(prob; scenario); + FT = get_hybridproblem_float_type(probc; scenario) + par_templates = get_hybridproblem_par_templates(probc; scenario) + int_P, int_M = map(ComponentArrayInterpreter, par_templates) + pbm_covars = get_hybridproblem_pbmpar_covars(probc; scenario) pbm_covar_indices = CP.get_pbm_covar_indices(par_templates.θP, pbm_covars) + #θsite_true = get_hybridproblem_par_templates(probc; scenario) + n_site, n_batch = get_hybridproblem_n_site_and_batch(probc; scenario) + # note: need to use prob rather than probc here, make sure the same + rng = StableRNG(111) + (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, + y_unc) = gen_hybridproblem_synthetic(rng, prob; scenario) + tmpf = () -> begin + # wrap inside function to not define(pollute) variables in level up + _trainloader = get_hybridproblem_train_dataloader(probc; scenario) + (_xM, _xP, _y_o, _y_unc, _i_sites) = _trainloader.data + @test _xM == xM + @test _y_o == y_o + end; tmpf() - #θsite_true = get_hybridproblem_par_templates(prob; scenario) - n_covar = 5 - n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; 
scenario) - (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc - ) = gen_hybridproblem_synthetic(rng, prob; scenario); - - g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario); - f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) - f_pred = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true) - - n_θM, n_θP = values(map(length, get_hybridproblem_par_templates(prob; scenario))) + # prediction by g(ϕg, XM) does not correspond to θMs_true, randomly initialized + # only the magnitude is there because of NormalScaling and prior + g, ϕg0 = get_hybridproblem_MLapplicator(probc; scenario) + f = get_hybridproblem_PBmodel(probc; scenario, use_all_sites=false) + f_pred = get_hybridproblem_PBmodel(probc; scenario, use_all_sites=true) + n_θM, n_θP = values(map(length, par_templates)) py = neg_logden_indep_normal n_MC = 3 - (; transP, transM) = get_hybridproblem_transforms(prob; scenario) - cor_ends = get_hybridproblem_cor_ends(prob; scenario) + (; transP, transM) = get_hybridproblem_transforms(probc; scenario) + cor_ends = get_hybridproblem_cor_ends(probc; scenario) # transP = elementwise(exp) # transM = Stacked(elementwise(identity), elementwise(exp)) #transM = Stacked(elementwise(identity), elementwise(exp), elementwise(exp)) # test mismatch ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT)) + hpints = HybridProblemInterpreters(probc; scenario) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP_true, θMs_true[:, 1], cor_ends, ϕg0, n_batch; transP, transM); + θP_true, θMs_true[:, 1], cor_ends, ϕg0, hpints; transP, transM) + int_unc = interpreters.unc + int_μP_ϕg_unc = interpreters.μP_ϕg_unc + + # @descend_code_warntype init_hybrid_params(θP_true, θMs_true[:, 1], cor_ends, ϕg0, n_batch; transP, transM) + # @descend_code_warntype CA.ComponentVector(nt) ϕ_ini = ϕ + transform_tools = nothing # TODO remove + # transform_tools = @inferred CP.setup_transform_ζ( + # transP, transM, 
get_int_PMst_batch(hpints)) + int_PMs = get_int_PMst_batch(hpints) if ggdev isa MLDataDevices.AbstractGPUDevice - scenario_flux = (scenario..., :use_Flux, :use_gpu) - g_flux, ϕg0_flux_cpu = get_hybridproblem_MLapplicator(prob; scenario = scenario_flux) + scenario_flux = Val((CP._val_value(scenario)..., :use_Flux, :use_gpu)) + probc_dev = HybridProblem(prob; scenario = scenario_flux); + g_flux, ϕg0_flux_cpu = get_hybridproblem_MLapplicator( + probc_dev; scenario=scenario_flux) g_gpu = ggdev(g_flux) - end; + end + + ζsP, ζsMs, σ = @inferred ( + # @descend_code_warntype ( + CP.generate_ζ( + rng, g, ϕ_ini, xM[:, 1:n_batch]; + n_MC, cor_ends, pbm_covar_indices, + int_unc=interpreters.unc, int_μP_ϕg_unc=interpreters.μP_ϕg_unc) + ) - @testset "generate_ζ" begin + @testset "generate_ζ $(last(CP._val_value(scenario)))" begin # xMtest = vcat(xM, xM[1:1,:]) # ζ, σ = CP.generate_ζ( # rng, g, ϕ_ini, xMtest[:, 1:n_batch], map(get_concrete, interpreters); # n_MC = 8, cor_ends, pbm_covar_indices) - ζ, σ = CP.generate_ζ( - rng, g, ϕ_ini, xM[:, 1:n_batch], map(get_concrete, interpreters); - n_MC = 8, cor_ends, pbm_covar_indices) - @test ζ isa Matrix + @test ζsP isa AbstractMatrix + @test ζsMs isa AbstractArray + @test size(ζsP) == (n_θP, n_MC) + @test size(ζsMs) == (n_batch, n_θM, n_MC) gr = Zygote.gradient( - # ϕ -> sum(CP.generate_ζ( - # rng, g, ϕ, xMtest[:, 1:n_batch], map(get_concrete, interpreters); - # n_MC = 8, cor_ends, pbm_covar_indices)[1]), - ϕ -> sum(CP.generate_ζ( - rng, g, ϕ, xM[:, 1:n_batch], map(get_concrete, interpreters); - n_MC = 8, cor_ends, pbm_covar_indices)[1]), - CA.getdata(ϕ_ini)) + ϕ -> begin + _ζsP, _ζsMs, _σ = CP.generate_ζ( + rng, g, ϕ, xM[:, 1:n_batch]; + n_MC=8, cor_ends, pbm_covar_indices, + int_unc=interpreters.unc, int_μP_ϕg_unc=interpreters.μP_ϕg_unc) + sum(_ζsP) + sum(_ζsMs) + sum(_σ) + end, CA.getdata(ϕ_ini)) @test gr[1] isa Vector - end; + end + + if !(:covarK2 ∈ CP._val_value(scenario)) + # can only test distribution if g is not repeated + 
@testset "generate_ζ check sd residuals $(last(CP._val_value(scenario)))" begin + # prescribe very different uncertainties + ϕunc_true = copy(probc.ϕunc) + sd_ζP_true = [0.2,20] + sd_ζMs_a_true = [0.1,2] # sd at_variance at θ==0 + logσ2_ζMs_b_true = [-0.3,+0.2] # slope of log_variance with θ + ρsP_true = [+0.8] + ρsM_true = [-0.6] + + ϕunc_true.logσ2_ζP = (log ∘ abs2).(sd_ζP_true) + ϕunc_true.coef_logσ2_ζMs[1,:] = (log ∘ abs2).(sd_ζMs_a_true) + ϕunc_true.coef_logσ2_ζMs[2,:] = logσ2_ζMs_b_true + ϕunc_true.ρsP = ρsP_true + ϕunc_true.ρsM = ρsM_true + probd = CP.update(probc; ϕunc=ϕunc_true); + _ϕ = vcat(ϕ_ini.μP, probc.ϕg, probd.ϕunc) + #hcat(ϕ_ini, ϕ, _ϕ)[1:4,:] + #hcat(ϕ_ini, ϕ, _ϕ)[(end-20):end,:] + n_predict = 80000 + xM_batch = xM[:, 1:n_batch] + _ζsP, _ζsMs, _σ = @inferred ( + # @descend_code_warntype ( + CP.generate_ζ( + rng, g, _ϕ, xM_batch; + n_MC = n_predict, cor_ends, pbm_covar_indices, + int_unc=interpreters.unc, int_μP_ϕg_unc=interpreters.μP_ϕg_unc) + ) + ζMs_g = g(xM_batch, probc.ϕg)' # have been generated with no scaling + function test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g) + mP = mean(_ζsP; dims=2) + residP = _ζsP .- mP + sdP = vec(std(residP; dims=2)) + _sd_ζP_true = sqrt.(exp.(ϕunc_true.logσ2_ζP)) + @test isapprox(sdP, _sd_ζP_true; rtol=0.05) + mMs = mean(_ζsMs; dims=3)[:,:,1] + hcat(mMs, ζMs_g) + # @usingany UnicodePlots + #scatterplot(ζMs_g[:,1], mMs[:,1]) + #scatterplot(ζMs_g[:,2], mMs[:,2]) + @test cor(ζMs_g[:,1], mMs[:,1]) > 0.9 + @test cor(ζMs_g[:,2], mMs[:,2]) > 0.8 + map(axes(mMs,2)) do ipar + #@show ipar + @test isapprox(mMs[:,ipar], ζMs_g[:,ipar]; rtol=0.1) + end + #ζMs_true = stack(map(inverse(transM), eachcol(CA.getdata(θMs_true[:,1:n_batch]))))' + residMs = _ζsMs .- mMs + sdMs = std(residMs; dims=3)[:,:,1] + # (_a,_b), mMi = first(zip( + # eachcol(ϕunc_true.coef_logσ2_ζMs), eachcol(mMs))) + _sd_ζMs_true = stack(map( + eachcol(ϕunc_true.coef_logσ2_ζMs), eachcol(ζMs_g)) do (_a,_b), mMi + #eachcol(ϕunc_true.coef_logσ2_ζMs), eachcol(mMs)) do 
(_a,_b), mMi + logσ2_ζM = _a .+ mMi .* _b + sqrt.(exp.(logσ2_ζM)) + end) + #ipar = 2 + #ipar = 1 + map(axes(sdMs,2)) do ipar + #@show ipar + hcat(sdMs[:,ipar], _sd_ζMs_true[:,ipar]) + @test isapprox(sdMs[:,ipar], _sd_ζMs_true[:,ipar]; rtol=0.2) + # scatterplot(sdMs[:,ipar], _sd_ζMs_true[:,ipar]) + end + i_sites_inspect = [1,2,3] + # reshape to par-first so that can inspect correlations better + residMst = permutedims(residMs[i_sites_inspect,:,:], (2,1,3)) + residPMst = vcat(residP, + reshape(residMst, size(residMst,1)*size(residMst,2), size(residMst,3))) + cor_PMs = cor(residPMst') + @test cor_PMs[1,2] ≈ ρsP_true[1] atol=0.2 + @test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.2)) # no correlations P,M + @test cor_PMs[3,4] ≈ ρsM_true[1] atol=0.2 + @test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.2)) # no correlations M1, M2 + @test cor_PMs[5,6] ≈ ρsM_true[1] atol=0.2 + @test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.2)) # no correlations M1, M2 + end + test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g) + @testset "predict_hvi check sd" begin + # test if uncertainty and reshaping is propagated + # here inverse the predicted θs and then test distribution + probcu = CP.update(probc, ϕunc=ϕunc_true); + n_sample_pred = 24_000 + (; y, θsP, θsMs, entropy_ζ) = predict_hvi(rng, probcu; scenario, n_sample_pred); + #size(_ζsMs), size(θsMs) + #size(_ζsP), size(θsP) + trans_minvP = StackedArray(inverse(transP), n_sample_pred) + _ζsP2 = trans_minvP(θsP) + int_minvM = StackedArray(inverse(transM), n_site) + _ζsMs2 = stack(map(eachslice(θsMs; dims=3)) do _θMs + int_minvM(_θMs) + end) + ζMs_g2 = g(xM, probcu.ϕg)' # have been generated with no scaling + test_distζ(_ζsP2, _ζsMs2, ϕunc_true, ζMs_g2) + end; + end; + end # if covar in scenario if ggdev isa MLDataDevices.AbstractGPUDevice - @testset "generate_ζ gpu" begin + @testset "generate_ζ gpu $(last(CP._val_value(scenario)))" begin ϕ = ggdev(CA.getdata(ϕ_ini)) @test g_gpu.μ isa GPUArraysCore.AbstractGPUArray + # @test g_gpu.app isa 
HybridVariationalInferenceFluxExt.FluxApplicator xMg_batch = ggdev(xM[:, 1:n_batch]) - ζ, σ = CP.generate_ζ( - rng, g_gpu, ϕ, xMg_batch, map(get_concrete, interpreters); - n_MC = 8, cor_ends, pbm_covar_indices) - @test ζ isa GPUArraysCore.AbstractGPUMatrix - @test eltype(ζ) == FT + ζsP_d, ζsMs_d, σ_d = @inferred ( + # @descend_code_warntype ( + CP.generate_ζ( + rng, g_gpu, ϕ, xMg_batch; + n_MC, cor_ends, pbm_covar_indices, + int_unc=interpreters.unc, int_μP_ϕg_unc=interpreters.μP_ϕg_unc)) + @test ζsP_d isa Union{GPUArraysCore.AbstractGPUMatrix, + LinearAlgebra.Adjoint{FT,<:GPUArraysCore.AbstractGPUMatrix}} + @test ζsMs_d isa Union{GPUArraysCore.AbstractGPUArray, + LinearAlgebra.Adjoint{FT,<:GPUArraysCore.AbstractGPUArray}} + @test eltype(ζsP_d) == eltype(ζsMs_d) == FT + @test size(ζsP_d) == (n_θP, n_MC) + @test size(ζsMs_d) == (n_batch, n_θM, n_MC) gr = Zygote.gradient( - ϕ -> sum(CP.generate_ζ( - rng, g_gpu, ϕ, xMg_batch, map(get_concrete, interpreters); - n_MC = 8, cor_ends, pbm_covar_indices)[1]), - ϕ) + ϕ -> begin + _ζsP, _ζsMs, _σ = CP.generate_ζ( + rng, g_gpu, ϕ, xMg_batch; + n_MC, cor_ends, pbm_covar_indices, + int_unc=interpreters.unc, int_μP_ϕg_unc=interpreters.μP_ϕg_unc) + sum(_ζsP) + sum(_ζsMs) + sum(_σ) + end, CA.getdata(ϕ)) @test gr[1] isa GPUArraysCore.AbstractGPUVector end end - @testset "neg_elbo_gtf cpu" begin + @testset "transform_and_logjac_ζ $(last(CP._val_value(scenario)))" begin + # reorder Ms columns so that first parameter of all sites is first + # # transforming entire parameter set across n_MC most efficient but + # # does not yield logdetjac + # intm_PMs_gen = get_ca_int_PMs(n_batch); + # pos_intm_PMs = get_positions(intm_PMs_gen) + # function trans_ζs_crossMC(ζs::AbstractMatrix, pos_intm_PMs::NamedTuple; n_MC = size(ζs,2)) + # ζstMs = ζs'[1:n_MC, pos_intm_PMs.Ms'] # n_MC x n_site_batch x n_par + # ζstP = ζs'[1:n_MC, pos_intm_PMs.P] # n_MC x n_par + # transPM = extend_stacked_nrow(transP, n_MC) + # θsP = reshape(transPM(vec(ζstP)), 
size(ζstP)) + # transMM = extend_stacked_nrow(transM, n_MC * n_batch) + # θsMs = reshape(transMM(vec(ζstMs)), size(ζstMs)) + # (θsP, θsMs) + # end + # (θsP, θsMs) = trans_ζs(ζs, pos_intm_PMs; n_MC) + # @test size(θsP) == (n_MC, n_θP) + # @test size(θsMs) == (n_MC, n_batch, n_θM) + # map by rows + ζP, ζMs = ζsP[:, 1], ζsMs[:, :, 1] + n_site_batch = size(ζMs, 1) + transMs = StackedArray(transM, n_site_batch) + θP, θMs, logjac = @inferred CP.transform_and_logjac_ζ(ζP, ζMs; transP, transMs) + @test size(θP) == (n_θP,) + @test size(θMs) == (n_site_batch, n_θM) + @test θP == transP(ζP) + @test θMs[1, :] == transM(ζMs[1, :]) + @test θMs[end, :] == transM(ζMs[end, :]) + if ggdev isa MLDataDevices.AbstractGPUDevice + ζPdev, ζMsdev = ggdev.((ζP, ζMs)) + θP, θMs, logjac = @inferred CP.transform_and_logjac_ζ( + ζPdev, ζMsdev; transP, transMs) + @test size(θP) == (n_θP,) + @test size(θMs) == (n_site_batch, n_θM) + gr = Zygote.gradient(ζPdev, ζMsdev) do ζPdev, ζMsdev + θP, θMs, logjac = CP.transform_and_logjac_ζ(ζPdev, ζMsdev; transP, transMs) + sum(θP) + sum(θMs) + logjac + end + @test eltype(gr[1]) == eltype(ζPdev) + @test eltype(gr[2]) == eltype(ζMsdev) + end + end + + @testset "transform_ζs $(last(CP._val_value(scenario)))" begin + n_site_batch, _, n_MC = size(ζsMs) + trans_mP = StackedArray(transP, n_MC) + trans_mMs = StackedArray(transM, n_MC * n_site_batch) + θsP, θsMs = @inferred CP.transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs) + @test size(θsP) == (n_θP, n_MC) + @test size(θsMs) == (n_site_batch, n_θM, n_MC) + @test θsP[:, 1] == transP(ζsP[:, 1]) + @test θsP[:, end] == transP(ζsP[:, end]) + @test θsMs[1, :, 1] == transM(ζsMs[1, :, 1]) # first parameter + @test θsMs[end, :, 1] == transM(ζsMs[end, :, 1]) + @test θsMs[1, :, end] == transM(ζsMs[1, :, end]) # last parameter + @test θsMs[end, :, end] == transM(ζsMs[end, :, end]) + if ggdev isa MLDataDevices.AbstractGPUDevice + ζsPdev, ζsMsdev = ggdev.((ζsP, ζsMs)) + #trans_mP(ζsPdev) + θsP, θsMs = @inferred 
CP.transform_ζs(ζsPdev, ζsMsdev; trans_mP, trans_mMs) + gr = Zygote.gradient(ζsPdev, ζsMsdev) do ζsPdev, ζsMsdev + θsP, θsMs = CP.transform_ζs(ζsPdev, ζsMsdev; trans_mP, trans_mMs) + sum(θsP) + sum(θsMs) + end + @test eltype(gr[1]) == eltype(ζsPdev) + @test eltype(gr[2]) == eltype(ζsMsdev) + end + end + + @testset "neg_elbo_gtf cpu $(last(CP._val_value(scenario)))" begin i_sites = 1:n_batch - cost = neg_elbo_gtf(rng, ϕ_ini, g, transPMs_batch, f, py, - xM[:, i_sites], xP[:,i_sites], y_o[:, i_sites], y_unc[:, i_sites], i_sites, - map(get_concrete, interpreters); - cor_ends, pbm_covar_indices) + transMs = StackedArray(transM, size(ζsMs, 1)) + cost = @inferred ( + #@descend_code_warntype ( + neg_elbo_gtf(rng, ϕ_ini, g, f, py, + xM[:, i_sites], xP[:, i_sites], y_o[:, i_sites], y_unc[:, i_sites], i_sites; + int_unc, int_μP_ϕg_unc, + cor_ends, pbm_covar_indices, transP, transMs) + ) @test cost isa Float64 gr = Zygote.gradient( - ϕ -> neg_elbo_gtf(rng, ϕ, g, transPMs_batch, f, py, - xM[:, i_sites], xP[:,i_sites], y_o[:, i_sites], y_unc[:, i_sites], i_sites, - map(get_concrete, interpreters); - cor_ends, pbm_covar_indices), + ϕ -> neg_elbo_gtf(rng, ϕ, g, f, py, + xM[:, i_sites], xP[:, i_sites], y_o[:, i_sites], y_unc[:, i_sites], i_sites; + int_unc, int_μP_ϕg_unc, + cor_ends, pbm_covar_indices, transP, transMs), CA.getdata(ϕ_ini)) @test gr[1] isa Vector - end; + end if ggdev isa MLDataDevices.AbstractGPUDevice - @testset "neg_elbo_gtf gpu" begin + @testset "neg_elbo_gtf gpu $(last(CP._val_value(scenario)))" begin i_sites = 1:n_batch + transMs = StackedArray(transM, size(ζsMs, 1)) ϕ = ggdev(CA.getdata(ϕ_ini)) xMg_batch = ggdev(xM[:, i_sites]) - xP_batch = xP[:,i_sites] # used in f which runs on CPU - cost = neg_elbo_gtf(rng, ϕ, g_gpu, transPMs_batch, f, py, - xMg_batch, xP_batch, y_o[:, i_sites], y_unc[:, i_sites], i_sites, - map(get_concrete, interpreters); - n_MC = 3, cor_ends, pbm_covar_indices) + xP_batch = xP[:, i_sites] # used in f which runs on CPU + cost = @inferred 
( + #@descend_code_warntype ( + neg_elbo_gtf(rng, ϕ, g_gpu, f, py, + xMg_batch, xP_batch, y_o[:, i_sites], y_unc[:, i_sites], i_sites; + int_unc, int_μP_ϕg_unc, + n_MC=3, cor_ends, pbm_covar_indices, transP, transMs) + ) @test cost isa Float64 gr = Zygote.gradient( - ϕ -> neg_elbo_gtf(rng, ϕ, g_gpu, transPMs_batch, f, py, - xMg_batch, xP_batch, y_o[:, i_sites], y_unc[:, i_sites], i_sites, - map(get_concrete, interpreters); - n_MC = 3, cor_ends, pbm_covar_indices), + ϕ -> neg_elbo_gtf(rng, ϕ, g_gpu, f, py, + xMg_batch, xP_batch, y_o[:, i_sites], y_unc[:, i_sites], i_sites; + int_unc, int_μP_ϕg_unc, + n_MC=3, cor_ends, pbm_covar_indices, transP, transMs), ϕ) @test gr[1] isa GPUArraysCore.AbstractGPUVector @test eltype(gr[1]) == FT end end - @testset "predict_gf cpu" begin - n_sample_pred = n_site = 200 - intm_PMs_gen = get_ca_int_PMs(n_site) - trans_PMs_gen = get_transPMs(n_site) - @test length(intm_PMs_gen) == 402 - @test trans_PMs_gen.length_in == 402 - (; θ, y) = predict_gf(rng, g, f_pred, ϕ_ini, xM, xP, map(get_concrete, interpreters); - get_transPMs, get_ca_int_PMs, n_sample_pred, cor_ends, pbm_covar_indices) - @test θ isa CA.ComponentMatrix - @test θ[:, 1].P.r0 > 0 + @testset "apply_f $(last(CP._val_value(scenario)))" begin + ζP = ζsP[:,1] + ζMs = ζsMs[:,:,1] + y_pred, θP_pred, θMs_pred = @inferred CP.apply_f_trans( + ζP, ζMs, f, xP[:,1:n_batch]; transP, transM) + @test size(y_pred) == size(y_o[:,1:n_batch]) + @test size(θP_pred) == (n_θP,) + @test size(θMs_pred) == (n_batch, n_θM) + # + ym_pred, θPm_pred, θMsm_pred = CP.apply_f_trans( + ζsP[:,1:1], ζMs[:,:,1:1], f, xP[:,1:n_batch]; transP, transM) + @test ym_pred[:,:,1] == y_pred + @test θPm_pred[:,1] == θP_pred + @test θMsm_pred[:,:,1] == θMs_pred + end + + @testset "predict_hvi cpu $(last(CP._val_value(scenario)))" begin + # intm_PMs_gen = get_ca_int_PMs(n_site) + # trans_PMs_gen = get_transPMs(n_site) + # @test length(intm_PMs_gen) == 402 + # @test trans_PMs_gen.length_in == 402 + n_sample_pred = 30 + (; y, 
θsP, θsMs, entropy_ζ) = + #Cthulhu.@descend_code_warntype ( + @inferred ( + predict_hvi(rng, g, f_pred, ϕ_ini, xM, xP; + int_μP_ϕg_unc, int_unc, + transP, transM, + n_sample_pred, cor_ends, pbm_covar_indices) + ) + @test θsP isa AbstractMatrix + @test θsMs isa AbstractArray{T,3} where {T} + int_mP = ComponentArrayInterpreter(int_P, (size(θsP, 2),)) + θsPc = int_mP(θsP) + @test all(θsPc[:r0, :] .> 0) @test y isa Array @test size(y) == (size(y_o)..., n_sample_pred) end if ggdev isa MLDataDevices.AbstractGPUDevice - @testset "predict_gf gpu" begin - n_sample_pred = 200 - ϕ = ggdev(CA.getdata(ϕ_ini)) + @testset "predict_hvi gpu $(last(CP._val_value(scenario)))" begin + n_sample_pred = 32 + ϕ_ini_g = ggdev(CA.getdata(ϕ_ini)) xMg = ggdev(xM) - (; θ, y) = predict_gf(rng, g_gpu, f_pred, ϕ, xMg, xP, map(get_concrete, interpreters); - get_transPMs, get_ca_int_PMs, n_sample_pred, cor_ends, pbm_covar_indices) - @test θ isa CA.ComponentMatrix # only ML parameters are on gpu - @test θ[:, 1].P.r0 > 0 + n_sample_pred = 30 + (; y, θsP, θsMs, entropy_ζ) = + #Cthulhu.@descend_code_warntype ( + @inferred ( + predict_hvi(rng, g_gpu, f_pred, ϕ_ini_g, xMg, xP; + int_μP_ϕg_unc, int_unc, + transP, transM, + n_sample_pred, cor_ends, pbm_covar_indices) + ) + @test θsP isa AbstractMatrix + @test θsMs isa AbstractArray{T,3} where {T} + int_mP = ComponentArrayInterpreter(int_P, (size(θsP, 2),)) + θsPc = int_mP(θsP) + @test all(θsPc[:r0, :] .> 0) @test y isa Array @test size(y) == (size(y_o)..., n_sample_pred) end - # @testset "predict_gf also f on gpu" begin + # @testset "predict_hvi also f on gpu" begin # # currently only works with identity transformations but not elementwise(exp) - # transPM_ident = get_hybridproblem_transforms(prob; scenario = (scenario..., :transIdent)) + # transPM_ident = get_hybridproblem_transforms(probc; scenario = (scenario..., :transIdent)) # get_transPMs_ident = (() -> begin # # wrap in function to not override get_transPMs # (; get_transPMs) = init_hybrid_params( 
@@ -184,7 +437,7 @@ test_scenario = (scenario) -> begin # n_sample_pred = 200 # ϕ = ggdev(CA.getdata(ϕ_ini)) # xMg = ggdev(xM) - # (; θ, y) = predict_gf(rng, g_gpu, f_pred, ϕ, xMg, ggdev(xP), map(get_concrete, interpreters); + # (; θ, y) = predict_hvi(rng, g_gpu, f_pred, ϕ, xMg, ggdev(xP), map(get_concrete, interpreters); # get_ca_int_PMs, n_sample_pred, cor_ends, pbm_covar_indices, # get_transPMs = get_transPMs_ident, # cdev = identity); # keep on gpu @@ -194,13 +447,12 @@ test_scenario = (scenario) -> begin # @test y isa GPUArraysCore.AbstractGPUArray # @test size(y) == (size(y_o)..., n_sample_pred) # end + end # if ggdev - end end # test_scenario -test_scenario((:default,)) - -# with providing process parameter as additional covariate -test_scenario((:covarK2,)) +test_scenario(Val((:default,))) +# with providing process parameter as additional covariate +test_scenario(Val((:covarK2,))) diff --git a/test/test_hybridprobleminterpreters.jl b/test/test_hybridprobleminterpreters.jl new file mode 100644 index 0000000..dc18c06 --- /dev/null +++ b/test/test_hybridprobleminterpreters.jl @@ -0,0 +1,80 @@ +using Test +using HybridVariationalInference +using HybridVariationalInference: HybridVariationalInference as CP +using ComponentArrays: ComponentArrays as CA + +using MLDataDevices, GPUArraysCore +import Zygote + +# import CUDA, cuDNN +using Suppressor + +gdev = Suppressor.@suppress gpu_device() # not loaded CUDA +cdev = cpu_device() + +scenario = Val((:default,)) +prob = DoubleMM.DoubleMMCase() + +ints = @inferred HybridProblemInterpreters(prob; scenario) +θP, θM = @inferred get_hybridproblem_par_templates(prob; scenario) +NS, NB = @inferred get_hybridproblem_n_site_and_batch(prob; scenario) + +@testset "HybridProblemInterpreters" begin + @test (@inferred get_int_P(ints)(CA.getdata(θP))) == θP + @test (@inferred get_int_M(ints)(CA.getdata(θM))) == θM + # + int_Ms_batch = get_concrete(ComponentArrayInterpreter(θM, (NB,))) + ms_vec = 1:length(int_Ms_batch) + @test 
(@inferred get_int_Ms_batch(ints)(ms_vec)) == int_Ms_batch(ms_vec) + int_Mst_batch = get_concrete(ComponentArrayInterpreter((NB,), θM)) + @test (@inferred get_int_Mst_batch(ints)(ms_vec)) == int_Mst_batch(ms_vec) + # + int_Ms_site = get_concrete(ComponentArrayInterpreter(θM, (NS,))) + ms_vec = 1:length(int_Ms_site) + @test (@inferred get_int_Ms_site(ints)(ms_vec)) == int_Ms_site(ms_vec) + int_Mst_site = get_concrete(ComponentArrayInterpreter((NS,), θM)) + @test (@inferred get_int_Mst_site(ints)(ms_vec)) == int_Mst_site(ms_vec) + # + pms_ca = CA.ComponentVector(P = θP, Ms = int_Ms_batch(1:length(int_Ms_batch))) + pms_vec = CA.getdata(pms_ca) + #int_PMs_batch = get_concrete(ComponentArrayInterpreter(pms_ca)) + @test (@inferred get_int_PMs_batch(ints)(pms_vec)) == pms_ca + pmst_ca = CA.ComponentVector(P = θP, Ms = int_Mst_batch(1:length(int_Mst_batch))) + pmst_vec = CA.getdata(pmst_ca) + @test (@inferred get_int_PMst_batch(ints)(pmst_vec)) == pmst_ca + # + pms_ca = CA.ComponentVector(P = θP, Ms = int_Ms_site(1:length(int_Ms_site))) + pms_vec = CA.getdata(pms_ca) + @test (@inferred get_int_PMs_site(ints)(pms_vec)) == pms_ca + pmst_ca = CA.ComponentVector(P = θP, Ms = int_Mst_site(1:length(int_Mst_site))) + pmst_vec = CA.getdata(pmst_ca) + @test (@inferred get_int_PMst_site(ints)(pmst_vec)) == pmst_ca +end; + +@testset "stack_ca_int" begin + int_Mst_batch = get_int_Mst_batch(ints) + pmst_ca = CA.ComponentVector(P = θP, Ms = int_Mst_batch(1:length(int_Mst_batch))) + n_pred = 5 + mmst_vec = repeat(CA.getdata(pmst_ca)', n_pred) # column per parameter + int_PMst_batch = @inferred get_int_PMst_batch(ints) + intm_PMst_batch = @inferred stack_ca_int(Val((n_pred,)), int_PMst_batch) + mmst = @inferred intm_PMst_batch(mmst_vec) + @test size(mmst[1, :Ms]) == (NB, length(θM)) + @test all(mmst[:, :P][:, :r0] .== pmst_ca.P.r0) + # + # note the use of Val here -> arrays interpreted will by Any outside the context + @testset "stack_ca_int not inferred outside" begin + tmpf = (mmst_vec; 
+ intm_PMst_batch = @inferred stack_ca_int(Val((size(mmst_vec,1),)), int_PMst_batch) + ) -> begin + # good practise to help inference by providing a hint to the eltype + (@inferred intm_PMst_batch(mmst_vec))::CA.ComponentMatrix{eltype(mmst_vec),typeof(mmst_vec)} + end + res = tmpf(mmst_vec) + @test_broken @inferred tmpf(mmst_vec) + # but supplying the extended array, its inferred in this context + intm_PMst_batch2 = @inferred stack_ca_int(Val((size(mmst_vec,1),)), int_PMst_batch) + @inferred tmpf(mmst_vec; intm_PMst_batch = intm_PMst_batch2) + end +end + diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl index 4749c11..e16e6a4 100644 --- a/test/test_sample_zeta.jl +++ b/test/test_sample_zeta.jl @@ -7,7 +7,7 @@ using HybridVariationalInference: HybridVariationalInference as CP using StableRNGs import CUDA, cuDNN using GPUArraysCore: GPUArraysCore -using MLDataDevices +using MLDataDevices, Suppressor using Random #using SimpleChains using ComponentArrays: ComponentArrays as CA @@ -16,10 +16,11 @@ using StableRNGs #CUDA.device!(4) rng = StableRNG(111) -ggdev = gpu_device() +ggdev = Suppressor.@suppress gpu_device() +cdev = cpu_device() -const prob = DoubleMM.DoubleMMCase() -scenario = (:default,) +prob = DoubleMM.DoubleMMCase() +scenario = Val((:default,)) n_θM, n_θP = length.(values(get_hybridproblem_par_templates(prob; scenario))) @@ -27,77 +28,133 @@ n_θM, n_θP = length.(values(get_hybridproblem_par_templates(prob; scenario))) ) = gen_hybridproblem_synthetic(rng, prob; scenario) n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) - FT = get_hybridproblem_float_type(prob; scenario) # set to 0.02 rather than zero for debugging non-zero correlations -cor_ends = (P = 1:n_θP, M = [n_θM]) +cor_ends = (P=1:n_θP, M=[n_θM]) ρsP = zeros(FT, get_cor_count(cor_ends.P)) .+ FT(0.02) ρsM = zeros(FT, get_cor_count(cor_ends.M)) .+ FT(0.02) ϕunc = CA.ComponentVector(; - logσ2_logP = fill(FT(-10.0), n_θP), - coef_logσ2_logMs = reduce(hcat, (FT[-10.0, 
0.0] for _ in 1:n_θM)), + logσ2_ζP=fill(FT(-10.0), n_θP), + coef_logσ2_ζMs=reduce(hcat, (FT[-10.0, 0.0] for _ in 1:n_θM)), ρsP, ρsM) θ_true = θ = CA.ComponentVector(; - P = θP_true, - Ms = θMs_true) + P=θP_true, + Ms=θMs_true) transPMs = elementwise(exp) # all parameters on LogNormal scale ζ_true = inverse(transPMs)(θ_true) -ϕ_true = vcat(ζ_true, CA.ComponentVector(unc = ϕunc)) -ϕ_cpu = vcat(ζ_true .+ FT(0.01), CA.ComponentVector(unc = ϕunc)) +ϕ_true = vcat(ζ_true, CA.ComponentVector(unc=ϕunc)) +ϕ_cpu = vcat(ζ_true .+ FT(0.01), CA.ComponentVector(unc=ϕunc)) -interpreters = (; pmu = ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs) +interpreters = (; pmu=ComponentArrayInterpreter(ϕ_true), + unc=ComponentArrayInterpreter(ϕ_true.unc) +) #, M=int_θM, PMs=int_θPMs) n_MC = 3 -@testset "sample_ζ_norm0 cpu" begin + +@testset "transpose_Ms_sitefirst" begin + x_true = collect(1:8) + tmp = Iterators.take(enumerate(Iterators.repeated(x_true)), n_MC) + collect(tmp) + Xt = permutedims(stack(map(tmp) do (i, x) + 10 .* i .+ x + end)) + _nP = 2; _nM = 3; _nsite = 2 + intm_PMs_parfirst = ComponentArrayInterpreter( + P = (n_MC, _nP), Ms = (n_MC, _nM, _nsite)) + Xtc = intm_PMs_parfirst(Xt) + # + X = @inferred CP.transpose_mPMs_sitefirst(Xt, _nP, _nM, _nsite, n_MC) + # using Cthulhu + # @descend_code_warntype CP.transpose_mPMs_sitefirst(Xt, _nP, _nM, _nsite, n_MC) + intm_PMs_sitefirst = ComponentArrayInterpreter( + P = (n_MC, _nP), Ms = (n_MC, _nsite, _nM)) + Xc = intm_PMs_sitefirst(X) + @test Xc.P == Xtc.P + @test Xc.Ms[:,1,:] == Xtc.Ms[:,:,1] # first site + @test Xc.Ms[:,2,:] == Xtc.Ms[:,:,2] + @test Xc.Ms[:,:,2] == Xtc.Ms[:,2,:] # second parameter +end; + +@testset "sample_ζresid_norm" begin ϕ = CA.getdata(ϕ_cpu) ϕc = interpreters.pmu(ϕ) - ζ_resid, σ = CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC, cor_ends) - @test size(ζ_resid) == (length(ϕc.P) + n_θM * n_site, n_MC) - gr = Zygote.gradient( - ϕc -> sum(CP.sample_ζ_norm0( - rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC, 
cor_ends)[1]), ϕc)[1] + ϕc.unc.coef_logσ2_ζMs[1,:] .= (log ∘ abs2).((0.1, 100.0)) + ϕc.unc.ρsM .= 0.0 + int_unc = get_concrete(ComponentArrayInterpreter(ϕc.unc)) + n_MC_pred = 300 # larger n_MC to test σ2 + n_site_batch = size(ϕc.Ms,2) + ζP_resids, ζMs_parfirst_resids, σ = @inferred CP.sample_ζresid_norm(rng, ϕc.P, ϕc.Ms, ϕc.unc; + n_MC=n_MC_pred, cor_ends, int_unc) + # ζ_resid, σ = @inferred CP.sample_ζresid_norm(rng, ϕc.P, ϕc.Ms, ϕc.unc; + # n_MC, cor_ends, int_unc = interpreters.unc) + #@usingany Cthulhu + #@descend_code_warntype CP.sample_ζresid_norm(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC, cor_ends, int_unc = get_concrete(interpreters.unc)) + #@descend_code_warntype CP.sample_ζresid_norm(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC, cor_ends, int_unc = interpreters.unc) + #@test size(ζ_resid) == (length(ϕc.P) + n_site * n_θM, n_MC) + n_θM = size(ϕc.Ms,1) + @test size(ζP_resids) == (n_θP, n_MC_pred) + @test size(ζMs_parfirst_resids) == (n_θM, n_site_batch, n_MC_pred) + gr = Zygote.gradient(ϕc -> begin + ζP_resids, ζMs_parfirst_resids, σ = CP.sample_ζresid_norm( + rng, ϕc.P, ϕc.Ms, ϕc.unc; + n_MC, cor_ends, int_unc) + sum(ζP_resids) + sum(ζMs_parfirst_resids) + end, ϕc)[1] @test length(gr) == length(ϕ) -end -# - -if ggdev isa MLDataDevices.AbstractGPUDevice - @testset "sample_ζ_norm0 gpu" begin - # sample only n_batch of 50 - n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) - ϕb = CA.ComponentVector(P = ϕ_cpu.P, Ms = ϕ_cpu.Ms[:,1:n_batch], unc = ϕ_cpu.unc) - intb = ComponentArrayInterpreter(ϕb) - ϕ = ggdev(CA.getdata(ϕb)) - #tmp = ϕ[1:6] - #vec2uutri(tmp) - ϕc = intb(ϕ) - @test CA.getdata(ϕc) isa GPUArraysCore.AbstractGPUArray - #ζP, ζMs, ϕunc = ϕc.P, ϕc.Ms, ϕc.unc - #urand = CUDA.randn(length(ϕc.P) + length(ϕc.Ms), n_MC) |> gpu - #include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss - #ζ_resid, σ = sample_ζ_norm0(urand, ϕc.P, ϕc.Ms, ϕc.unc; n_MC) - #Zygote.gradient(ϕc -> sum(sample_ζ_norm0(urand, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)[1]), ϕc)[1]; - int_unc = 
ComponentArrayInterpreter(ϕc.unc) - ζ_resid, σ = CP.sample_ζ_norm0( - rng, CA.getdata(ϕc.P), CA.getdata(ϕc.Ms), CA.getdata(ϕc.unc), int_unc; - n_MC, cor_ends) - @test ζ_resid isa GPUArraysCore.AbstractGPUArray - @test size(ζ_resid) == (length(ϕc.P) + n_θM * n_batch, n_MC) - gr = Zygote.gradient( - ϕc -> sum(CP.sample_ζ_norm0( - rng, CA.getdata(ϕc.P), CA.getdata(ϕc.Ms), CA.getdata(ϕc.unc), int_unc; - n_MC, cor_ends)[1]), ϕc)[1]; - @test length(gr) == length(ϕ) - @test CA.getdata(gr) isa GPUArraysCore.AbstractGPUArray - Array(gr) - int_unc = ComponentArrayInterpreter(ϕc.unc) - gr2 = Zygote.gradient( - ϕc -> sum(CP.sample_ζ_norm0(rng, CA.getdata(ϕc.P), CA.getdata(ϕc.Ms), - CA.getdata(ϕc.unc), int_unc; n_MC, cor_ends)[1]), - ϕc)[1]; + # + n_θM, n_site_batch = size(ϕc.Ms) + # intm_PMs = ComponentArrayInterpreter( + # P = (n_MC_pred, n_θP), Ms = (n_MC_pred, n_site_batch, n_θM)) + # xc = intm_PMs(ζ_resid) + # isapprox(std(xc.Ms[:,1,1]), 0.1, rtol = 0.1) # site 1 parameter 1 + # isapprox(std(xc.Ms[:,:,1]), 0.1, rtol = 0.1) # parameter 1 + # isapprox(std(xc.Ms[:,:,2]), 100.1, rtol = 0.1) # parameter 2 + isapprox(std(ζMs_parfirst_resids[1,1,:]), 0.1, rtol = 0.1) # site 1 parameter 1 + isapprox(std(ζMs_parfirst_resids[1,:,:]), 0.1, rtol = 0.1) # parameter 1 + isapprox(std(ζMs_parfirst_resids[2,:,:]), 100.1, rtol = 0.1) # parameter 2 + + # + if ggdev isa MLDataDevices.AbstractGPUDevice + @testset "sample_ζresid_norm gpu" begin + ϕcd = CP.apply_preserve_axes(ggdev, ϕc); # semicolon necessary + @test CA.getdata(ϕcd) isa GPUArraysCore.AbstractGPUArray + #ζP, ζMs, ϕunc = ϕc.P, ϕc.Ms, ϕc.unc + #urandn = CUDA.randn(length(ϕc.P) + length(ϕc.Ms), n_MC) |> gpu + #include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss + #ζ_resid, σ = sample_ζresid_norm(urandn, ϕc.P, ϕc.Ms, ϕc.unc; n_MC) + #Zygote.gradient(ϕc -> sum(sample_ζresid_norm(urandn, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)[1]), ϕc)[1]; + ζP_resids, ζMs_parfirst_resids, σ = @inferred CP.sample_ζresid_norm( + rng, CA.getdata(ϕcd.P), 
CA.getdata(ϕcd.Ms), CA.getdata(ϕcd.unc); + n_MC = n_MC_pred, cor_ends, int_unc) + #@descend_code_warntype CP.sample_ζresid_norm(rng, CA.getdata(ϕc.P), CA.getdata(ϕc.Ms), CA.getdata(ϕc.unc); n_MC, cor_ends, int_unc) + @test ζP_resids isa GPUArraysCore.AbstractGPUArray + @test ζMs_parfirst_resids isa GPUArraysCore.AbstractGPUArray + @test size(ζP_resids) == (n_θP, n_MC_pred) + @test size(ζMs_parfirst_resids) == (n_θM, n_site_batch, n_MC_pred) + # Zygote gradient for many sites, use fewer sites here + n_site_few = 20 + ϕcd_few = CA.ComponentVector(; P = ϕcd.P, Ms = ϕcd.Ms[:,1:n_site_few], unc = ϕcd.unc); + gr = Zygote.gradient(ϕc -> begin + ζP_resids, ζMs_parfirst_resids, σ = CP.sample_ζresid_norm( + rng, CA.getdata(ϕc.P), CA.getdata(ϕc.Ms), CA.getdata(ϕc.unc); + n_MC, cor_ends, int_unc) + sum(ζP_resids) + sum(ζMs_parfirst_resids) + end, ϕcd_few)[1]; # semicolon required + # gr = Zygote.gradient( + # ϕc -> sum(CP.sample_ζresid_norm( + # rng, CA.getdata(ϕc.P), CA.getdata(ϕc.Ms), CA.getdata(ϕc.unc); + # n_MC, cor_ends, int_unc)[1]), ϕcd_few)[1]; # need semicolon + # @test CA.getdata(gr) isa GPUArraysCore.AbstractGPUArray + # CP.apply_preserve_axes(cdev, gr) + # + isapprox(std(ζMs_parfirst_resids[1,1,:]), 0.1, rtol = 0.1) # site 1 parameter 1 + isapprox(std(ζMs_parfirst_resids[1,:,:]), 0.1, rtol = 0.1) # parameter 1 + isapprox(std(ζMs_parfirst_resids[2,:,:]), 100.1, rtol = 0.1) # parameter 2 + end end end diff --git a/test/test_util_ca.jl b/test/test_util_ca.jl new file mode 100644 index 0000000..1370c9c --- /dev/null +++ b/test/test_util_ca.jl @@ -0,0 +1,19 @@ +using Test +using HybridVariationalInference +using HybridVariationalInference: HybridVariationalInference as CP +using ComponentArrays: ComponentArrays as CA + +@testset "compose_axes" begin + @test (@inferred CP._add_interval(;ranges=(Val(1:3),), length = Val(2))) == (Val(1:3), Val(4:5)) + ls = Val.((3,1,2)) + @test (@inferred CP._construct_invervals(;lengths=ls)) == Val.((1:3, 4:4, 5:6)) + v1 = 
CA.ComponentVector(A=1:3) + v2 = CA.ComponentVector(B=1:2) + v3 = CA.ComponentVector(P=(x=1, y=2), Ms=zeros(3,2)) + nt = (;C1=v1, C2=v2, C3=v3) + vt = CA.ComponentVector(; nt...) + axs = map(CA.getaxes, nt) + axc = @inferred CP.compose_axes(axs) + @test axc == CA.getaxes(vt)[1] +end +