Skip to content

Commit 7192975

Browse files
committed
compare HVI to HMC inversion
HVI is wrong for thetaP but HMC is often wrong for the means of thetaM and also precitions - did not find global minimum
1 parent 6d2b657 commit 7192975

File tree

1 file changed

+88
-3
lines changed

1 file changed

+88
-3
lines changed

dev/doubleMM.jl

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ end
504504
(; θ, y) = HVI.predict_ζf(ζs, f, xP, trans_PMs_gen, intm_PMs_gen);
505505
(ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y);
506506

507-
mean_y_invζ = map(mean, eachslice(y; dims = (1, 2)));
507+
mean_y_invζ = mean_y_hmc = map(mean, eachslice(y_hmc; dims = (1, 2)));
508508
#describe(mean_y_pred - y_o)
509509
histogram(vec(mean_y_invζ) - vec(y_true)) # predictions centered around y_o (or y_true)
510510
plt = scatterplot(vec(y_true), vec(mean_y_invζ));
@@ -534,7 +534,7 @@ lineplot!(plt, 0, 1)
534534
ζs_hvi = log.(θ2)
535535
int_pms = interpreters.PMs
536536
par_pos = int_pms(1:length(int_pms))
537-
#i_sites = 1:5
537+
i_sites = 1:10
538538
#i_sites = 6:10
539539
#i_sites = 11:15
540540
scen = vcat(fill(:hvi,size(ζs_hvi,2)),fill(:hmc,size(ζs_hmc,2)))
@@ -582,9 +582,94 @@ lineplot!(plt, 0, 1)
582582
plt = (data(df) * mapping(color=:Method) * density(datalimits=extrema) +
583583
data(df_true) * visual(VLines; color=:blue, linestyle=:dash)) *
584584
mapping(:value=>"", col=:variable => sorter(vcat(keys(θP)..., keys(θM)...)), row = (:site => nonnumeric))
585-
fig = draw(plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), axis=(xlabelvisible=false,)).figure
585+
fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2));
586+
f = draw!(fig, plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), axis=(xlabelvisible=false,));
587+
fig
586588
save("tmp.svg", fig)
587589
save_with_config("intermediate/compare_hmc_hvi_sites", fig; makie_config)
590+
#
591+
# compare density of predictions
592+
y_hvi = y2
593+
i_obss = [1,4,8]
594+
#i_obss = 1:8
595+
dfy = mapreduce(vcat, i_obss) do i_obs
596+
mapreduce(vcat, i_sites) do i_site
597+
vcat(
598+
DataFrame(
599+
value = y_hmc[i_obs,i_site,:],
600+
site = i_site,
601+
Method = :hmc,
602+
variable = :y,
603+
i_obs = i_obs,
604+
),
605+
DataFrame(
606+
value = y_hvi[i_obs,i_site,:],
607+
site = i_site,
608+
Method = :hvi,
609+
variable = :y,
610+
i_obs = i_obs,
611+
)
612+
)# vcat
613+
end
614+
end
615+
dfyt = mapreduce(vcat, i_obss) do i_obs
616+
mapreduce(vcat, i_sites) do i_site
617+
vcat(
618+
DataFrame(
619+
value = y_true[i_obs,i_site],
620+
site = i_site,
621+
Method = :truth,
622+
variable = :y,
623+
i_obs = i_obs,
624+
),
625+
DataFrame(
626+
value = y_o[i_obs,i_site,:],
627+
site = i_site,
628+
Method = :obs,
629+
variable = :y,
630+
i_obs = i_obs,
631+
)
632+
)# vcat
633+
end
634+
end
635+
636+
plt = (data(dfy) * mapping(color=:Method) * density(datalimits=extrema) +
637+
data(dfyt) * mapping(color=:Method) * visual(VLines; linestyle=:dash)) *
638+
#data(dfyt) * mapping(color=:Method, linestyle=:Method) * visual(VLines; linestyle=:dash)) *
639+
mapping(:value=>"", col=:i_obs => nonnumeric, row = :site => nonnumeric)
640+
641+
function get_fig_size(cfg; width2height=golden_ratio, xfac=1.0)
642+
cfg = makie_config
643+
x_inch = first(cfg.size_inches) * xfac
644+
y_inch = x_inch / width2height
645+
72 .* (x_inch, y_inch) ./ cfg.pt_per_unit # size_pt
646+
end
647+
fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2));
648+
f = draw!(fig, plt,
649+
facet=(; linkxaxes=:minimal, linkyaxes=:none,),
650+
axis=(xlabelvisible=false,yticklabelsvisible=false));
651+
legend!(fig[1,3], f, ; tellwidth=false, halign=:right, valign=:top) # , margin=(-10, -10, 10, 10)
652+
fig
653+
save_with_config("intermediate/compare_hmc_hvi_sites_y", fig; makie_config)
654+
# hvi predicts y better, hmc fails for quite a few obs: 3,5,6
655+
656+
# todo compare mean_predictions
657+
mean_y_hvi = map(mean, eachslice(y_hvi; dims = (1, 2)));
658+
659+
660+
end
661+
662+
() -> begin # inspect correlation of residuals
663+
mean_ζ_hvi = map(mean, eachrow(CA.getdata(ζs_hvi)))
664+
r_hvi = ζs_hvi .- mean_ζ_hvi
665+
cor_hvi = cor(CA.getdata(r_hvi)')
666+
mean_ζ_hmc = map(mean, eachrow(CA.getdata(ζs_hmc)))
667+
r_hmc = ζs_hmc .- mean_ζ_hmc
668+
cor_hmc = cor(CA.getdata(r_hmc)')
669+
#
670+
hcat(cor_hvi[:,1], cor_hmc[:,1])
671+
# positive correlations of K2(1) in θP with K1(3) in θMs
672+
588673
end
589674

590675

0 commit comments

Comments
 (0)