Skip to content

Commit f199b31

Browse files
committed
implement Ml model depending on PBM parameters
1 parent 841fecb commit f199b31

File tree

7 files changed

+196
-66
lines changed

7 files changed

+196
-66
lines changed

.github/workflows/CompatHelper.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ jobs:
1414
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
1515
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
1616
run: julia -e 'using CompatHelper; CompatHelper.main()'
17-
keepalive-job:
18-
name: Keepalive Workflow
19-
runs-on: ubuntu-latest
20-
permissions:
21-
actions: write
22-
steps:
23-
- uses: actions/checkout@v4
24-
with:
25-
ref: 'keepalive' # The branch, tag or SHA to checkout.
26-
- uses: gautamkrishnar/keepalive-workflow@v2
17+
# keepalive-job:
18+
# name: Keepalive Workflow
19+
# runs-on: ubuntu-latest
20+
# permissions:
21+
# actions: write
22+
# steps:
23+
# - uses: actions/checkout@v4
24+
# with:
25+
# ref: 'keepalive' # The branch, tag or SHA to checkout.
26+
# - uses: gautamkrishnar/keepalive-workflow@v2

dev/doubleMM.jl

Lines changed: 90 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ rng = StableRNG(115)
1717
scenario = NTuple{0, Symbol}()
1818
scenario = (:omit_r0,) # without omit_r0 ambiguous K2 estimated to high
1919
scenario = (:use_Flux, :use_gpu)
20-
scenario = (:use_Flux, :use_gpu, :omit_r0)
2120
scenario = (:use_Flux, :use_gpu, :omit_r0, :few_sites)
21+
scenario = (:use_Flux, :use_gpu, :omit_r0, :few_sites, :covarK2)
22+
scenario = (:use_Flux, :use_gpu, :omit_r0, :sites20, :covarK2)
23+
scenario = (:use_Flux, :use_gpu, :omit_r0)
24+
scenario = (:use_Flux, :use_gpu, :omit_r0, :covarK2)
2225
# prob = DoubleMM.DoubleMMCase()
2326

2427
gdev = :use_gpu scenario ? gpu_device() : identity
@@ -48,11 +51,11 @@ solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01), n_ba
4851
#solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
4952
n_batches_in_epoch = n_site ÷ solver_point.n_batch
5053
n_epoch = 80
51-
(; ϕ, resopt) = solve(prob0, solver_point; scenario,
54+
(; ϕ, resopt, probo) = solve(prob0, solver_point; scenario,
5255
rng, callback = callback_loss(n_batches_in_epoch * 10),
5356
maxiters = n_batches_in_epoch * n_epoch);
5457
# update the problem with optimized parameters
55-
prob0o = HVI.update(prob0; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP)
58+
prob0o = probo;
5659
y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario);
5760
plt = scatterplot(θMs_true[1, :], θMs[1, :]);
5861
lineplot!(plt, 0, 1)
@@ -146,20 +149,21 @@ probh = prob0o # start from point optimized to infer uncertainty
146149
#probh = prob1o # start from point optimized to infer uncertainty
147150
#probh = prob0 # start from no information
148151
solver_post = HybridPosteriorSolver(;
149-
alg = OptimizationOptimisers.Adam(0.01), n_batch = 50, n_MC = 3)
152+
alg = OptimizationOptimisers.Adam(0.01), n_batch = min(50, n_site), n_MC = 3)
150153
#solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
151154
n_batches_in_epoch = n_site ÷ solver_post.n_batch
152155
n_epoch = 80
153-
(; ϕ, θP, resopt, interpreters) = solve(probh, solver_post; scenario,
154-
rng, callback = callback_loss(n_batches_in_epoch * n_epoch),
156+
(; ϕ, θP, resopt, interpreters, probo) = solve(probh, solver_post; scenario,
157+
rng, callback = callback_loss(n_batches_in_epoch * 5),
155158
maxiters = n_batches_in_epoch * n_epoch,
156159
θmean_quant = 0.05);
157-
probh.get_train_loader(;n_batch = 50, scenario)
160+
#probh.get_train_loader(;n_batch = 50, scenario)
158161
# update the problem with optimized parameters, including uncertainty
159-
probo = prob1o = HVI.update(prob0o; ϕg = cpu_ca(ϕ).ϕg, θP = θP, ϕunc = cpu_ca(ϕ).unc)
162+
prob1o = probo;
160163
n_sample_pred = 400
161-
(; θ, y) = predict_gf(rng, prob1o, xM, xP; scenario, n_sample_pred);
162-
(θ1, y1) = (θ, y)
164+
#(; θ, y) = predict_gf(rng, prob1o, xM, xP; scenario, n_sample_pred);
165+
(; θ, y) = predict_gf(rng, prob1o; scenario, n_sample_pred);
166+
(θ1, y1) = (θ, y);
163167

164168
() -> begin # prediction with fitted parameters (should be smaller than mean)
165169
y_pred_global, y_pred2, θMs = gf(prob1o, xM, xP; scenario)
@@ -170,7 +174,7 @@ n_sample_pred = 400
170174
end
171175

172176
#----------- continue HVI without strong prior on θmean
173-
prob2 = HVI.update(prob1o)
177+
prob2 = HVI.update(prob1o); # copy
174178
function fstate_ϕunc(state)
175179
u = state.u |> cpu
176180
#Main.@infiltrate_main
@@ -179,15 +183,36 @@ function fstate_ϕunc(state)
179183
end
180184
n_epoch = 100
181185
#n_epoch = 400
182-
(; ϕ, θP, resopt, interpreters) = solve(prob2, HVI.update(solver_post, n_MC = 12);
186+
(; ϕ, θP, resopt, interpreters, probo) = solve(prob2,
187+
HVI.update(solver_post, n_MC = 12);
188+
#HVI.update(solver_post, n_MC = 30);
183189
scenario, rng, maxiters = n_batches_in_epoch * n_epoch,
184190
callback = HVI.callback_loss_fstate(n_batches_in_epoch*5, fstate_ϕunc));
185-
probo = prob2o = HVI.update(prob2; ϕg = cpu_ca(ϕ).ϕg, θP = θP, ϕunc = cpu_ca(ϕ).unc);
191+
prob2o = probo;
186192

187193
() -> begin
188194
using JLD2
189-
fname_probos = "intermediate/probos.jld2"
195+
#fname_probos = "intermediate/probos_$(last(scenario)).jld2"
196+
fname_probos = "intermediate/probos800_$(last(scenario)).jld2"
190197
JLD2.save(fname_probos, Dict("prob1o" => prob1o, "prob2o" => prob2o))
198+
tmp = JLD2.load(fname_probos)
199+
# get_train_loader function could not be restored with JLD2
200+
prob1o = HVI.update(tmp["prob1o"], get_train_loader = prob0.get_train_loader);
201+
prob2o = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader);
202+
end
203+
204+
() -> begin # load the non-covar scenario
205+
using JLD2
206+
#fname_probos = "intermediate/probos_$(last(scenario)).jld2"
207+
fname_probos = "intermediate/probos800_omit_r0.jld2"
208+
tmp = JLD2.load(fname_probos)
209+
# get_train_loader function could not be restored with JLD2
210+
prob1o_indep = HVI.update(tmp["prob1o"], get_train_loader = prob0.get_train_loader);
211+
prob2o_indep = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader);
212+
# test predicting correct obs-uncertainty of predictive posterior
213+
n_sample_pred = 400
214+
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o_indep, xM, xP; scenario, n_sample_pred);
215+
(θ2_indep, y2_indep) = (θ, y)
191216
end
192217

193218
() -> begin # otpimize using LUX
@@ -221,7 +246,7 @@ exp.(ϕunc_VI.coef_logσ2_logMs[1, :])
221246

222247
# test predicting correct obs-uncertainty of predictive posterior
223248
n_sample_pred = 400
224-
(; θ, y, entropy_ζ) = predict_gf(rng, probo, xM, xP; scenario, n_sample_pred);
249+
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o, xM, xP; scenario, n_sample_pred);
225250
(θ2, y2) = (θ, y)
226251
size(y) # n_obs x n_site, n_sample_pred
227252
size(θ) # n_θP + n_site * n_θM x n_sample
@@ -243,6 +268,7 @@ mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ
243268
plt = scatterplot(θMs_true[1, :], mean_θ.Ms[1, :]);
244269
lineplot!(plt, 0, 1)
245270
plt = scatterplot(θMs_true[2, :], mean_θ.Ms[2, :])
271+
histogram(θ[:P,:])
246272
#scatter(fig[1,1], CA.getdata(θMs_true[1, :]), CA.getdata(mean_θ.Ms[1, :])); ablines!(fig[1,1], 0, 1)
247273
#@usingany AlgebraOfGraphices
248274
#fig = Figure()
@@ -332,9 +358,9 @@ end
332358
fig = Figure(; size = (640, 480))
333359
fig = Figure(; size = (320, 240))
334360
gp = fig[1, 1]
335-
f = draw!(gp, plt + plth)
361+
fd = draw!(gp, plt + plth)
336362
legend!(
337-
gp, f; tellwidth = false, halign = :right, valign = :top, margin = (10, 10, 10, 10))
363+
gp, fd; tellwidth = false, halign = :right, valign = :top, margin = (10, 10, 10, 10))
338364
save("r1_density.pdf", fig)
339365
save("tmp.svg", fig)
340366

@@ -349,9 +375,9 @@ end
349375
plth = mapping([0.0]) * visual(VLines; linestyle = :dash)
350376
fig = Figure(; size = (640, 480))
351377
gp = fig[1, 1]
352-
f = draw!(gp, plt + plth)
378+
fg = draw!(gp, plt + plth)
353379
legend!(
354-
gp, f; tellwidth = false, halign = :right, valign = :top, margin = (10, 10, 10, 10))
380+
gp, fg; tellwidth = false, halign = :right, valign = :top, margin = (10, 10, 10, 10))
355381
save("ys_density.pdf", fig)
356382
save("tmp.svg", fig)
357383

@@ -362,9 +388,9 @@ end
362388
aog.density()
363389
fig = Figure()
364390
gp = fig[1, 1]
365-
f = draw!(gp, plt)
391+
fg = draw!(gp, plt)
366392
legend!(
367-
gp, f; tellwidth = false, halign = :right, valign = :top, margin = (10, 10, 10, 10))
393+
gp, fg; tellwidth = false, halign = :right, valign = :top, margin = (10, 10, 10, 10))
368394
save("negLogDensity.pdf", fig)
369395
save("tmp.svg", fig)
370396
end
@@ -444,7 +470,9 @@ g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario)
444470
θP, θM, cor_ends, ϕg0, n_site; transP, transM, ϕunc0);
445471

446472
intm_PMs_gen = get_ca_int_PMs(n_site);
473+
#intm_PMs_gen = get_ca_int_PMs(100);
447474
trans_PMs_gen = get_transPMs(n_site);
475+
#trans_PMs_gen = get_transPMs(100);
448476

449477
"""
450478
ζMs in chain are all first parameter, all second parameters, ...
@@ -465,7 +493,8 @@ end
465493
# mle_estimate.values
466494

467495
# takes ~ 25 minutes
468-
n_sample_NUTS = 800
496+
#n_sample_NUTS = 800
497+
n_sample_NUTS = 2000
469498
#chain = sample(model, NUTS(), n_sample_NUTS, initial_params = ζ0_true .+ 0.001)
470499
#n_sample_NUTS = 24
471500
n_threads = 8
@@ -475,10 +504,15 @@ chain = sample(model, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_thread
475504

476505
() -> begin
477506
using JLD2
478-
jldsave("intermediate/doubleMM_chain_zeta.jld2", false, IOStream; chain)
479-
chain = load("intermediate/doubleMM_chain_zeta.jld2", "chain"; iotype = IOStream)
507+
fname = "intermediate/doubleMM_chain_zeta_$(last(scenario)).jld2"
508+
jldsave(fname, false, IOStream; chain)
509+
chain = load(fname, "chain"; iotype = IOStream)
480510
end
481511

512+
#ζi = first(eachrow(Array(chain)))
513+
ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true), hcat, eachrow(Array(chain)));
514+
(; θ, y) = HVI.predict_ζf(ζs, f, xP, trans_PMs_gen, intm_PMs_gen);
515+
(ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y);
482516

483517

484518
() -> begin # check that the model predicts the same as HVI-code
@@ -494,21 +528,18 @@ end
494528

495529
() -> begin # plot chain
496530
#@usingany TwMakieHelpers, CairoMakie
497-
ch = chain[:,vcat(1:2,n_θP+n_site+1),:];
531+
# θP and first θMs
532+
ch = chain[:,vcat(1:n_θP, n_θP+1, n_θP+n_site+1),:];
498533
fig = plot_chn(ch)
499534
save("tmp.svg", fig)
500535
end
501536

502-
#ζi = first(eachrow(Array(chain)))
503-
ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true), hcat, eachrow(Array(chain)));
504-
(; θ, y) = HVI.predict_ζf(ζs, f, xP, trans_PMs_gen, intm_PMs_gen);
505-
(ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y);
506537

507538
mean_y_invζ = mean_y_hmc = map(mean, eachslice(y_hmc; dims = (1, 2)));
508539
#describe(mean_y_pred - y_o)
509540
histogram(vec(mean_y_invζ) - vec(y_true)) # predictions centered around y_o (or y_true)
510541
plt = scatterplot(vec(y_true), vec(mean_y_invζ));
511-
lineplot!(plt, 0, 2)
542+
lineplot!(plt, 0, 1)
512543
mean(mean_y_invζ - y_true) # still ok
513544

514545
# first site, first prediction
@@ -531,17 +562,25 @@ lineplot!(plt, 0, 1)
531562
() -> begin # compare against HVI sample
532563
#@usingany AlgebraOfGraphics, TwPrototypes, CairoMakie, DataFrames
533564
makie_config = ppt_MakieConfig()
565+
function get_fig_size(cfg; width2height=golden_ratio, xfac=1.0)
566+
cfg = makie_config
567+
x_inch = first(cfg.size_inches) * xfac
568+
y_inch = x_inch / width2height
569+
72 .* (x_inch, y_inch) ./ cfg.pt_per_unit # size_pt
570+
end
571+
534572
ζs_hvi = log.(θ2)
573+
ζs_hvi_indep = log.(θ2_indep)
535574
int_pms = interpreters.PMs
536575
par_pos = int_pms(1:length(int_pms))
537576
i_sites = 1:10
538577
#i_sites = 6:10
539578
#i_sites = 11:15
540-
scen = vcat(fill(:hvi,size(ζs_hvi,2)),fill(:hmc,size(ζs_hmc,2)))
579+
scen = vcat(fill(:hvi,size(ζs_hvi,2)),fill(:hmc,size(ζs_hmc,2)),fill(:hvi_indep,size(ζs_hvi_indep,2)))
541580
dfP = mapreduce(vcat, axes(θP,1)) do i_par
542581
pos = par_pos.P[i_par]
543582
DataFrame(
544-
value = vcat(ζs_hvi[pos,:], ζs_hmc[pos,:]),
583+
value = vcat(ζs_hvi[pos,:], ζs_hmc[pos,:], ζs_hvi_indep[pos,:]),
545584
variable = keys(θP)[i_par],
546585
site = i_sites[1],
547586
Method = scen
@@ -551,7 +590,7 @@ lineplot!(plt, 0, 1)
551590
mapreduce(vcat, axes(θM,1)) do i_par
552591
pos = par_pos.Ms[i_par, i_site]
553592
DataFrame(
554-
value = vcat(ζs_hvi[pos,:], ζs_hmc[pos,:]),
593+
value = vcat(ζs_hvi[pos,:], ζs_hmc[pos,:], ζs_hvi_indep[pos,:]),
555594
variable = keys(θM)[i_par],
556595
site = i_site,
557596
Method = scen
@@ -579,14 +618,25 @@ lineplot!(plt, 0, 1)
579618
end
580619
end
581620
) # vcat
582-
plt = (data(df) * mapping(color=:Method) * density(datalimits=extrema) +
583-
data(df_true) * visual(VLines; color=:blue, linestyle=:dash)) *
584-
mapping(:value=>"", col=:variable => sorter(vcat(keys(θP)..., keys(θM)...)), row = (:site => nonnumeric))
621+
#cf90 = (x) -> quantile(x, [0.05,0.95])
622+
plt = (data(subset(df, :Method => ByRow(((:hvi,:hmc))))) * mapping(:value=> (x -> x ) => "", color=:Method) * AlgebraOfGraphics.density(datalimits=extrema) +
623+
data(df_true) * mapping(:value) * visual(VLines; color=:blue, linestyle=:dash)) *
624+
mapping(col=:variable => sorter(vcat(keys(θP)..., keys(θM)...)), row = (:site => nonnumeric))
585625
fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2));
586-
f = draw!(fig, plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), axis=(xlabelvisible=false,));
626+
fg = draw!(fig, plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), axis=(xlabelvisible=false,));
587627
fig
588628
save("tmp.svg", fig)
589-
save_with_config("intermediate/compare_hmc_hvi_sites", fig; makie_config)
629+
save_with_config("intermediate/compare_hmc_hvi_sites_$(last(scenario))", fig; makie_config)
630+
631+
plt = (data(subset(df, :Method => ByRow(((:hvi, :hvi_indep))))) * mapping(:value=> (x -> x ) => "", color=:Method) * AlgebraOfGraphics.density(datalimits=extrema) +
632+
data(df_true) * mapping(:value) * visual(VLines; color=:blue, linestyle=:dash)) *
633+
mapping(col=:variable => sorter(vcat(keys(θP)..., keys(θM)...)), row = (:site => nonnumeric))
634+
fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2));
635+
fg = draw!(fig, plt, facet=(; linkxaxes=:minimal, linkyaxes=:none,), axis=(xlabelvisible=false,));
636+
fig
637+
save("tmp.svg", fig)
638+
save_with_config("intermediate/compare_hvi_indep_sites_$(last(scenario))", fig; makie_config)
639+
590640
#
591641
# compare density of predictions
592642
y_hvi = y2
@@ -633,24 +683,18 @@ lineplot!(plt, 0, 1)
633683
end
634684
end
635685

636-
plt = (data(dfy) * mapping(color=:Method) * density(datalimits=extrema) +
686+
plt = (data(dfy) * mapping(color=:Method) * AlgebraOfGraphics.density(datalimits=extrema) +
637687
data(dfyt) * mapping(color=:Method) * visual(VLines; linestyle=:dash)) *
638688
#data(dfyt) * mapping(color=:Method, linestyle=:Method) * visual(VLines; linestyle=:dash)) *
639689
mapping(:value=>"", col=:i_obs => nonnumeric, row = :site => nonnumeric)
640690

641-
function get_fig_size(cfg; width2height=golden_ratio, xfac=1.0)
642-
cfg = makie_config
643-
x_inch = first(cfg.size_inches) * xfac
644-
y_inch = x_inch / width2height
645-
72 .* (x_inch, y_inch) ./ cfg.pt_per_unit # size_pt
646-
end
647691
fig = Figure(size = get_fig_size(makie_config, xfac=1, width2height = 1/2));
648692
f = draw!(fig, plt,
649693
facet=(; linkxaxes=:minimal, linkyaxes=:none,),
650694
axis=(xlabelvisible=false,yticklabelsvisible=false));
651695
legend!(fig[1,3], f, ; tellwidth=false, halign=:right, valign=:top) # , margin=(-10, -10, 10, 10)
652696
fig
653-
save_with_config("intermediate/compare_hmc_hvi_sites_y", fig; makie_config)
697+
save_with_config("intermediate/compare_hmc_hvi_sites_y_$(last(scenario))", fig; makie_config)
654698
# hvi predicts y better, hmc fails for quite a few obs: 3,5,6
655699

656700
# todo compare mean_predictions
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using Bijectors
2+
3+
using MLDataDevices
4+
import CUDA, cuDNN
5+
using Zygote
6+
7+
b2 = exp
8+
b2 = elementwise(exp)
9+
10+
x = [0.1, 0.2, 0.3, 0.4]
11+
b2 = Stacked((identity,),(1:length(x),))
12+
13+
function trans(x, b)
14+
y, logjac = Bijectors.with_logabsdet_jacobian(b, x)
15+
sum(y .+ logjac)
16+
end
17+
18+
y = trans(x,b2)
19+
Zygote.gradient(x -> trans(x,b2), x)
20+
21+
xd = gpu_device()(x)
22+
yd = trans(xd, b2)
23+
Zygote.gradient(x -> trans(x,b2), xd) # errors with elementwise

0 commit comments

Comments
 (0)