@@ -17,8 +17,11 @@ rng = StableRNG(115)
1717scenario = NTuple {0, Symbol} ()
1818scenario = (:omit_r0 ,) # without omit_r0 ambiguous K2 estimated to high
1919scenario = (:use_Flux , :use_gpu )
20- scenario = (:use_Flux , :use_gpu , :omit_r0 )
2120scenario = (:use_Flux , :use_gpu , :omit_r0 , :few_sites )
21+ scenario = (:use_Flux , :use_gpu , :omit_r0 , :few_sites , :covarK2 )
22+ scenario = (:use_Flux , :use_gpu , :omit_r0 , :sites20 , :covarK2 )
23+ scenario = (:use_Flux , :use_gpu , :omit_r0 )
24+ scenario = (:use_Flux , :use_gpu , :omit_r0 , :covarK2 )
2225# prob = DoubleMM.DoubleMMCase()
2326
2427gdev = :use_gpu ∈ scenario ? gpu_device () : identity
@@ -48,11 +51,11 @@ solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01), n_ba
4851# solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
4952n_batches_in_epoch = n_site ÷ solver_point. n_batch
5053n_epoch = 80
51- (; ϕ, resopt) = solve (prob0, solver_point; scenario,
54+ (; ϕ, resopt, probo ) = solve (prob0, solver_point; scenario,
5255 rng, callback = callback_loss (n_batches_in_epoch * 10 ),
5356 maxiters = n_batches_in_epoch * n_epoch);
5457# update the problem with optimized parameters
55- prob0o = HVI . update (prob0; ϕg = cpu_ca (ϕ) . ϕg, θP = cpu_ca (ϕ) . θP)
58+ prob0o = probo;
5659y_pred_global, y_pred, θMs = gf (prob0o, xM, xP; scenario);
5760plt = scatterplot (θMs_true[1 , :], θMs[1 , :]);
5861lineplot! (plt, 0 , 1 )
@@ -146,20 +149,21 @@ probh = prob0o # start from point optimized to infer uncertainty
146149# probh = prob1o # start from point optimized to infer uncertainty
147150# probh = prob0 # start from no information
148151solver_post = HybridPosteriorSolver (;
149- alg = OptimizationOptimisers. Adam (0.01 ), n_batch = 50 , n_MC = 3 )
152+ alg = OptimizationOptimisers. Adam (0.01 ), n_batch = min ( 50 , n_site) , n_MC = 3 )
150153# solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
151154n_batches_in_epoch = n_site ÷ solver_post. n_batch
152155n_epoch = 80
153- (; ϕ, θP, resopt, interpreters) = solve (probh, solver_post; scenario,
154- rng, callback = callback_loss (n_batches_in_epoch * n_epoch ),
156+ (; ϕ, θP, resopt, interpreters, probo ) = solve (probh, solver_post; scenario,
157+ rng, callback = callback_loss (n_batches_in_epoch * 5 ),
155158 maxiters = n_batches_in_epoch * n_epoch,
156159 θmean_quant = 0.05 );
157- probh. get_train_loader (;n_batch = 50 , scenario)
160+ # probh.get_train_loader(;n_batch = 50, scenario)
158161# update the problem with optimized parameters, including uncertainty
159- probo = prob1o = HVI . update (prob0o; ϕg = cpu_ca (ϕ) . ϕg, θP = θP, ϕunc = cpu_ca (ϕ) . unc)
162+ prob1o = probo;
160163n_sample_pred = 400
161- (; θ, y) = predict_gf (rng, prob1o, xM, xP; scenario, n_sample_pred);
162- (θ1, y1) = (θ, y)
164+ # (; θ, y) = predict_gf(rng, prob1o, xM, xP; scenario, n_sample_pred);
165+ (; θ, y) = predict_gf (rng, prob1o; scenario, n_sample_pred);
166+ (θ1, y1) = (θ, y);
163167
164168() -> begin # prediction with fitted parameters (should be smaller than mean)
165169 y_pred_global, y_pred2, θMs = gf (prob1o, xM, xP; scenario)
@@ -170,7 +174,7 @@ n_sample_pred = 400
170174end
171175
172176# ----------- continue HVI without strong prior on θmean
173- prob2 = HVI. update (prob1o)
177+ prob2 = HVI. update (prob1o); # copy
174178function fstate_ϕunc (state)
175179 u = state. u |> cpu
176180 # Main.@infiltrate_main
@@ -179,15 +183,36 @@ function fstate_ϕunc(state)
179183end
180184n_epoch = 100
181185# n_epoch = 400
182- (; ϕ, θP, resopt, interpreters) = solve (prob2, HVI. update (solver_post, n_MC = 12 );
186+ (; ϕ, θP, resopt, interpreters, probo) = solve (prob2,
187+ HVI. update (solver_post, n_MC = 12 );
188+ # HVI.update(solver_post, n_MC = 30);
183189 scenario, rng, maxiters = n_batches_in_epoch * n_epoch,
184190 callback = HVI. callback_loss_fstate (n_batches_in_epoch* 5 , fstate_ϕunc));
185- probo = prob2o = HVI . update (prob2; ϕg = cpu_ca (ϕ) . ϕg, θP = θP, ϕunc = cpu_ca (ϕ) . unc) ;
191+ prob2o = probo ;
186192
187193() -> begin
188194 using JLD2
189- fname_probos = " intermediate/probos.jld2"
195+ # fname_probos = "intermediate/probos_$(last(scenario)).jld2"
196+ fname_probos = " intermediate/probos800_$(last (scenario)) .jld2"
190197 JLD2. save (fname_probos, Dict (" prob1o" => prob1o, " prob2o" => prob2o))
198+ tmp = JLD2. load (fname_probos)
199+ # get_train_loader function could not be restored with JLD2
200+ prob1o = HVI. update (tmp[" prob1o" ], get_train_loader = prob0. get_train_loader);
201+ prob2o = HVI. update (tmp[" prob2o" ], get_train_loader = prob0. get_train_loader);
202+ end
203+
204+ () -> begin # load the non-covar scenario
205+ using JLD2
206+ # fname_probos = "intermediate/probos_$(last(scenario)).jld2"
207+ fname_probos = " intermediate/probos800_omit_r0.jld2"
208+ tmp = JLD2. load (fname_probos)
209+ # get_train_loader function could not be restored with JLD2
210+ prob1o_indep = HVI. update (tmp[" prob1o" ], get_train_loader = prob0. get_train_loader);
211+ prob2o_indep = HVI. update (tmp[" prob2o" ], get_train_loader = prob0. get_train_loader);
212+ # test predicting correct obs-uncertainty of predictive posterior
213+ n_sample_pred = 400
214+ (; θ, y, entropy_ζ) = predict_gf (rng, prob2o_indep, xM, xP; scenario, n_sample_pred);
215+ (θ2_indep, y2_indep) = (θ, y)
191216end
192217
193218() -> begin # otpimize using LUX
@@ -221,7 +246,7 @@ exp.(ϕunc_VI.coef_logσ2_logMs[1, :])
221246
222247# test predicting correct obs-uncertainty of predictive posterior
223248n_sample_pred = 400
224- (; θ, y, entropy_ζ) = predict_gf (rng, probo , xM, xP; scenario, n_sample_pred);
249+ (; θ, y, entropy_ζ) = predict_gf (rng, prob2o , xM, xP; scenario, n_sample_pred);
225250(θ2, y2) = (θ, y)
226251size (y) # n_obs x n_site, n_sample_pred
227252size (θ) # n_θP + n_site * n_θM x n_sample
@@ -243,6 +268,7 @@ mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ
243268plt = scatterplot (θMs_true[1 , :], mean_θ. Ms[1 , :]);
244269lineplot! (plt, 0 , 1 )
245270plt = scatterplot (θMs_true[2 , :], mean_θ. Ms[2 , :])
271+ histogram (θ[:P ,:])
246272# scatter(fig[1,1], CA.getdata(θMs_true[1, :]), CA.getdata(mean_θ.Ms[1, :])); ablines!(fig[1,1], 0, 1)
247273# @usingany AlgebraOfGraphices
248274# fig = Figure()
332358 fig = Figure (; size = (640 , 480 ))
333359 fig = Figure (; size = (320 , 240 ))
334360 gp = fig[1 , 1 ]
335- f = draw! (gp, plt + plth)
361+ fd = draw! (gp, plt + plth)
336362 legend! (
337- gp, f ; tellwidth = false , halign = :right , valign = :top , margin = (10 , 10 , 10 , 10 ))
363+ gp, fd ; tellwidth = false , halign = :right , valign = :top , margin = (10 , 10 , 10 , 10 ))
338364 save (" r1_density.pdf" , fig)
339365 save (" tmp.svg" , fig)
340366
349375 plth = mapping ([0.0 ]) * visual (VLines; linestyle = :dash )
350376 fig = Figure (; size = (640 , 480 ))
351377 gp = fig[1 , 1 ]
352- f = draw! (gp, plt + plth)
378+ fg = draw! (gp, plt + plth)
353379 legend! (
354- gp, f ; tellwidth = false , halign = :right , valign = :top , margin = (10 , 10 , 10 , 10 ))
380+ gp, fg ; tellwidth = false , halign = :right , valign = :top , margin = (10 , 10 , 10 , 10 ))
355381 save (" ys_density.pdf" , fig)
356382 save (" tmp.svg" , fig)
357383
362388 aog. density ()
363389 fig = Figure ()
364390 gp = fig[1 , 1 ]
365- f = draw! (gp, plt)
391+ fg = draw! (gp, plt)
366392 legend! (
367- gp, f ; tellwidth = false , halign = :right , valign = :top , margin = (10 , 10 , 10 , 10 ))
393+ gp, fg ; tellwidth = false , halign = :right , valign = :top , margin = (10 , 10 , 10 , 10 ))
368394 save (" negLogDensity.pdf" , fig)
369395 save (" tmp.svg" , fig)
370396end
@@ -444,7 +470,9 @@ g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario)
444470 θP, θM, cor_ends, ϕg0, n_site; transP, transM, ϕunc0);
445471
446472intm_PMs_gen = get_ca_int_PMs (n_site);
473+ # intm_PMs_gen = get_ca_int_PMs(100);
447474trans_PMs_gen = get_transPMs (n_site);
475+ # trans_PMs_gen = get_transPMs(100);
448476
449477"""
450478ζMs in chain are all first parameter, all second parameters, ...
465493# mle_estimate.values
466494
467495# takes ~ 25 minutes
468- n_sample_NUTS = 800
496+ # n_sample_NUTS = 800
497+ n_sample_NUTS = 2000
469498# chain = sample(model, NUTS(), n_sample_NUTS, initial_params = ζ0_true .+ 0.001)
470499# n_sample_NUTS = 24
471500n_threads = 8
@@ -475,10 +504,15 @@ chain = sample(model, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_thread
475504
476505() -> begin
477506 using JLD2
478- jldsave (" intermediate/doubleMM_chain_zeta.jld2" , false , IOStream; chain)
479- chain = load (" intermediate/doubleMM_chain_zeta.jld2" , " chain" ; iotype = IOStream)
507+ fname = " intermediate/doubleMM_chain_zeta_$(last (scenario)) .jld2"
508+ jldsave (fname, false , IOStream; chain)
509+ chain = load (fname, " chain" ; iotype = IOStream)
480510end
481511
512+ # ζi = first(eachrow(Array(chain)))
513+ ζs = mapreduce (ζi -> transposeMs (ζi, intm_PMs_gen, true ), hcat, eachrow (Array (chain)));
514+ (; θ, y) = HVI. predict_ζf (ζs, f, xP, trans_PMs_gen, intm_PMs_gen);
515+ (ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y);
482516
483517
484518() -> begin # check that the model predicts the same as HVI-code
@@ -494,21 +528,18 @@ end
494528
495529() -> begin # plot chain
496530 # @usingany TwMakieHelpers, CairoMakie
497- ch = chain[:,vcat (1 : 2 ,n_θP+ n_site+ 1 ),:];
531+ # θP and first θMs
532+ ch = chain[:,vcat (1 : n_θP, n_θP+ 1 , n_θP+ n_site+ 1 ),:];
498533 fig = plot_chn (ch)
499534 save (" tmp.svg" , fig)
500535end
501536
502- # ζi = first(eachrow(Array(chain)))
503- ζs = mapreduce (ζi -> transposeMs (ζi, intm_PMs_gen, true ), hcat, eachrow (Array (chain)));
504- (; θ, y) = HVI. predict_ζf (ζs, f, xP, trans_PMs_gen, intm_PMs_gen);
505- (ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y);
506537
507538mean_y_invζ = mean_y_hmc = map (mean, eachslice (y_hmc; dims = (1 , 2 )));
508539# describe(mean_y_pred - y_o)
509540histogram (vec (mean_y_invζ) - vec (y_true)) # predictions centered around y_o (or y_true)
510541plt = scatterplot (vec (y_true), vec (mean_y_invζ));
511- lineplot! (plt, 0 , 2 )
542+ lineplot! (plt, 0 , 1 )
512543mean (mean_y_invζ - y_true) # still ok
513544
514545# first site, first prediction
@@ -531,17 +562,25 @@ lineplot!(plt, 0, 1)
531562() -> begin # compare against HVI sample
532563 # @usingany AlgebraOfGraphics, TwPrototypes, CairoMakie, DataFrames
533564 makie_config = ppt_MakieConfig ()
565+ function get_fig_size (cfg; width2height= golden_ratio, xfac= 1.0 )
566+ cfg = makie_config
567+ x_inch = first (cfg. size_inches) * xfac
568+ y_inch = x_inch / width2height
569+ 72 .* (x_inch, y_inch) ./ cfg. pt_per_unit # size_pt
570+ end
571+
534572 ζs_hvi = log .(θ2)
573+ ζs_hvi_indep = log .(θ2_indep)
535574 int_pms = interpreters. PMs
536575 par_pos = int_pms (1 : length (int_pms))
537576 i_sites = 1 : 10
538577 # i_sites = 6:10
539578 # i_sites = 11:15
540- scen = vcat (fill (:hvi ,size (ζs_hvi,2 )),fill (:hmc ,size (ζs_hmc,2 )))
579+ scen = vcat (fill (:hvi ,size (ζs_hvi,2 )),fill (:hmc ,size (ζs_hmc,2 )), fill ( :hvi_indep , size (ζs_hvi_indep, 2 )) )
541580 dfP = mapreduce (vcat, axes (θP,1 )) do i_par
542581 pos = par_pos. P[i_par]
543582 DataFrame (
544- value = vcat (ζs_hvi[pos,:], ζs_hmc[pos,:]),
583+ value = vcat (ζs_hvi[pos,:], ζs_hmc[pos,:], ζs_hvi_indep[pos,:] ),
545584 variable = keys (θP)[i_par],
546585 site = i_sites[1 ],
547586 Method = scen
@@ -551,7 +590,7 @@ lineplot!(plt, 0, 1)
551590 mapreduce (vcat, axes (θM,1 )) do i_par
552591 pos = par_pos. Ms[i_par, i_site]
553592 DataFrame (
554- value = vcat (ζs_hvi[pos,:], ζs_hmc[pos,:]),
593+ value = vcat (ζs_hvi[pos,:], ζs_hmc[pos,:], ζs_hvi_indep [pos,:]),
555594 variable = keys (θM)[i_par],
556595 site = i_site,
557596 Method = scen
@@ -579,14 +618,25 @@ lineplot!(plt, 0, 1)
579618 end
580619 end
581620 ) # vcat
582- plt = (data (df) * mapping (color= :Method ) * density (datalimits= extrema) +
583- data (df_true) * visual (VLines; color= :blue , linestyle= :dash )) *
584- mapping (:value => " " , col= :variable => sorter (vcat (keys (θP)... , keys (θM)... )), row = (:site => nonnumeric))
621+ # cf90 = (x) -> quantile(x, [0.05,0.95])
622+ plt = (data (subset (df, :Method => ByRow (∈ ((:hvi ,:hmc ))))) * mapping (:value => (x -> x ) => " " , color= :Method ) * AlgebraOfGraphics. density (datalimits= extrema) +
623+ data (df_true) * mapping (:value ) * visual (VLines; color= :blue , linestyle= :dash )) *
624+ mapping (col= :variable => sorter (vcat (keys (θP)... , keys (θM)... )), row = (:site => nonnumeric))
585625 fig = Figure (size = get_fig_size (makie_config, xfac= 1 , width2height = 1 / 2 ));
586- f = draw! (fig, plt, facet= (; linkxaxes= :minimal , linkyaxes= :none ,), axis= (xlabelvisible= false ,));
626+ fg = draw! (fig, plt, facet= (; linkxaxes= :minimal , linkyaxes= :none ,), axis= (xlabelvisible= false ,));
587627 fig
588628 save (" tmp.svg" , fig)
589- save_with_config (" intermediate/compare_hmc_hvi_sites" , fig; makie_config)
629+ save_with_config (" intermediate/compare_hmc_hvi_sites_$(last (scenario)) " , fig; makie_config)
630+
631+ plt = (data (subset (df, :Method => ByRow (∈ ((:hvi , :hvi_indep ))))) * mapping (:value => (x -> x ) => " " , color= :Method ) * AlgebraOfGraphics. density (datalimits= extrema) +
632+ data (df_true) * mapping (:value ) * visual (VLines; color= :blue , linestyle= :dash )) *
633+ mapping (col= :variable => sorter (vcat (keys (θP)... , keys (θM)... )), row = (:site => nonnumeric))
634+ fig = Figure (size = get_fig_size (makie_config, xfac= 1 , width2height = 1 / 2 ));
635+ fg = draw! (fig, plt, facet= (; linkxaxes= :minimal , linkyaxes= :none ,), axis= (xlabelvisible= false ,));
636+ fig
637+ save (" tmp.svg" , fig)
638+ save_with_config (" intermediate/compare_hvi_indep_sites_$(last (scenario)) " , fig; makie_config)
639+
590640 #
591641 # compare density of predictions
592642 y_hvi = y2
@@ -633,24 +683,18 @@ lineplot!(plt, 0, 1)
633683 end
634684 end
635685
636- plt = (data (dfy) * mapping (color= :Method ) * density (datalimits= extrema) +
686+ plt = (data (dfy) * mapping (color= :Method ) * AlgebraOfGraphics . density (datalimits= extrema) +
637687 data (dfyt) * mapping (color= :Method ) * visual (VLines; linestyle= :dash )) *
638688 # data(dfyt) * mapping(color=:Method, linestyle=:Method) * visual(VLines; linestyle=:dash)) *
639689 mapping (:value => " " , col= :i_obs => nonnumeric, row = :site => nonnumeric)
640690
641- function get_fig_size (cfg; width2height= golden_ratio, xfac= 1.0 )
642- cfg = makie_config
643- x_inch = first (cfg. size_inches) * xfac
644- y_inch = x_inch / width2height
645- 72 .* (x_inch, y_inch) ./ cfg. pt_per_unit # size_pt
646- end
647691 fig = Figure (size = get_fig_size (makie_config, xfac= 1 , width2height = 1 / 2 ));
648692 f = draw! (fig, plt,
649693 facet= (; linkxaxes= :minimal , linkyaxes= :none ,),
650694 axis= (xlabelvisible= false ,yticklabelsvisible= false ));
651695 legend! (fig[1 ,3 ], f, ; tellwidth= false , halign= :right , valign= :top ) # , margin=(-10, -10, 10, 10)
652696 fig
653- save_with_config (" intermediate/compare_hmc_hvi_sites_y " , fig; makie_config)
697+ save_with_config (" intermediate/compare_hmc_hvi_sites_y_ $( last (scenario)) " , fig; makie_config)
654698 # hvi predicts y better, hmc fails for quite a few obs: 3,5,6
655699
656700 # todo compare mean_predictions
0 commit comments