@@ -28,35 +28,34 @@ gdev = :use_gpu ∈ scenario ? gpu_device() : identity
2828cdev = gdev isa MLDataDevices. AbstractGPUDevice ? cpu_device() : identity
2929
3030# ------ setup synthetic data and training data loader
31+ prob0_ = HybridProblem(DoubleMM. DoubleMMCase(); scenario);
3132(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
32- ) = gen_hybridproblem_synthetic(rng, DoubleMM . DoubleMMCase() ; scenario);
33- # n_site = get_hybridproblem_n_site(DoubleMM.DoubleMMCase() ; scenario)
33+ ) = gen_hybridproblem_synthetic(rng, prob0_ ; scenario);
34+ n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_ ; scenario)
3435ζP_true, ζMs_true = log.(θP_true), log.(θMs_true)
3536i_sites = 1 : n_site
36- xM_cpu = xM;
37- xM = xM_cpu |> gdev;
38- get_train_loader = (; n_batch, kwargs... ) -> MLUtils. DataLoader(
37+ n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario)
38+ train_dataloader = MLUtils. DataLoader(
3939 (xM, xP, y_o, y_unc, 1 : n_site);
4040 batchsize = n_batch, partial = false )
4141σ_o = exp.(y_unc[:, 1 ] / 2 )
42-
4342# assign the train_loader, otherwise it eatch time creates another version of synthetic data
44- prob0 = HVI. update(HybridProblem(DoubleMM . DoubleMMCase(); scenario); get_train_loader)
43+ prob0 = HVI. update(prob0_; train_dataloader);
4544# tmp = HVI.get_hybridproblem_ϕunc(prob0; scenario)
4645
4746# ------- pointwise hybrid model fit
48- solver_point = HybridPointSolver(; alg = OptimizationOptimisers. Adam(0.01 ), n_batch = 30 )
47+ solver_point = HybridPointSolver(; alg = OptimizationOptimisers. Adam(0.01 ))
4948# solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 30)
5049# solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
5150# solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
52- n_batches_in_epoch = n_site ÷ solver_point . n_batch
51+ n_batches_in_epoch = n_site ÷ n_batch
5352n_epoch = 80
5453(; ϕ, resopt, probo) = solve(prob0, solver_point; scenario,
5554 rng, callback = callback_loss(n_batches_in_epoch * 10 ),
5655 maxiters = n_batches_in_epoch * n_epoch);
5756# update the problem with optimized parameters
5857prob0o = probo;
59- y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario);
58+ y_pred_global, y_pred, θMs = gf(prob0o, scenario);
6059plt = scatterplot(θMs_true[1 , :], θMs[1 , :]);
6160lineplot!(plt, 0 , 1 )
6261scatterplot(θMs_true[2 , :], θMs[2 , :])
@@ -149,10 +148,10 @@ probh = prob0o # start from point optimized to infer uncertainty
149148# probh = prob1o # start from point optimized to infer uncertainty
150149# probh = prob0 # start from no information
151150solver_post = HybridPosteriorSolver(;
152- alg = OptimizationOptimisers. Adam(0.01 ), n_batch = min( 50 , n_site), n_MC = 3 )
151+ alg = OptimizationOptimisers. Adam(0.01 ), n_MC = 3 )
153152# solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
154- n_batches_in_epoch = n_site ÷ solver_post . n_batch
155- n_epoch = 80
153+ n_batches_in_epoch = n_site ÷ n_batch
154+ n_epoch = 40
156155(; ϕ, θP, resopt, interpreters, probo) = solve(probh, solver_post; scenario,
157156 rng, callback = callback_loss(n_batches_in_epoch * 5 ),
158157 maxiters = n_batches_in_epoch * n_epoch,
213212 n_sample_pred = 400
214213 (; θ, y, entropy_ζ) = predict_gf(rng, prob2o_indep, xM, xP; scenario, n_sample_pred);
215214 (θ2_indep, y2_indep) = (θ, y)
215+ # (θ2_indep, y2_indep) = (θ2, y2) # workaround to use covarK2 when loading failed
216216end
217217
218218() -> begin # otpimize using LUX
@@ -246,7 +246,7 @@ exp.(ϕunc_VI.coef_logσ2_logMs[1, :])
246246
247247# test predicting correct obs-uncertainty of predictive posterior
248248n_sample_pred = 400
249- (; θ, y, entropy_ζ) = predict_gf(rng, prob2o, xM, xP ; scenario, n_sample_pred);
249+ (; θ, y, entropy_ζ) = predict_gf(rng, prob2o; scenario, n_sample_pred);
250250(θ2, y2) = (θ, y)
251251size(y) # n_obs x n_site, n_sample_pred
252252size(θ) # n_θP + n_site * n_θM x n_sample
@@ -506,12 +506,13 @@ chain = sample(model, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_thread
506506 using JLD2
507507 fname = " intermediate/doubleMM_chain_zeta_$(last(scenario)) .jld2"
508508 jldsave(fname, false , IOStream; chain)
509- chain = load(fname, " chain" ; iotype = IOStream)
509+ chain = load(fname, " chain" ; iotype = IOStream);
510510end
511511
512512# ζi = first(eachrow(Array(chain)))
513+ f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true )
513514ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true ), hcat, eachrow(Array(chain)));
514- (; θ, y) = HVI. predict_ζf(ζs, f , xP, trans_PMs_gen, intm_PMs_gen);
515+ (; θ, y) = HVI. predict_ζf(ζs, f_allsites , xP, trans_PMs_gen, intm_PMs_gen);
515516(ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y);
516517
517518
0 commit comments