@@ -24,15 +24,16 @@ xM_cpu = xM
2424if :use_Flux ∈ scenario
2525 xM = CuArray (xM_cpu)
2626end
27- get_train_loader = (rng; n_batch, kwargs... ) -> MLUtils. DataLoader ((xM, xP, y_o, y_unc), batchsize = n_batch)
27+ get_train_loader = (rng; n_batch, kwargs... ) -> MLUtils. DataLoader ((xM, xP, y_o, y_unc);
28+ batchsize = n_batch, partial = false )
2829σ_o = exp (first (y_unc)/ 2 )
2930
3031# assign the train_loader, otherwise it eatch time creates another version of synthetic data
3132prob0 = HVI. update (HybridProblem (DoubleMM. DoubleMMCase (); scenario); get_train_loader)
3233
3334# ------- pointwise hybrid model fit
34- # solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
35- solver = HybridPointSolver (; alg = Adam (0.01 ), n_batch = 10 )
35+ solver = HybridPointSolver (; alg = Adam (0.02 ), n_batch = 30 )
36+ # solver = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
3637# solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
3738(; ϕ, resopt) = solve (prob0, solver; scenario,
3839 rng, callback = callback_loss (100 ), maxiters = 1200 );
@@ -116,70 +117,83 @@ end
116117 end
117118end
118119
119- # ----------- Hybrid Variational inference
120+ # ----------- Hybrid Variational inference: HVI
120121
121122using MLUtils
122123import Zygote
123124
124125using CUDA
125126using Bijectors
126127
128+ solver = HybridPosteriorSolver (; alg = Adam (0.01 ), n_batch = 60 , n_MC = 3 )
129+ # solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
130+ (; ϕ, θP, resopt) = solve (prob0o, solver; scenario,
131+ rng, callback = callback_loss (100 ), maxiters = 800 );
132+ # update the problem with optimized parameters
133+ prob1o = HVI. update (prob0o; ϕg= cpu_ca (ϕ). ϕg, θP= θP)
134+ y_pred_global, y_pred, θMs = gf (prob1o, xM, xP; scenario);
135+ scatterplot (θMs_true[1 ,:], θMs[1 ,:])
136+ scatterplot (θMs_true[2 ,:], θMs[2 ,:])
137+ hcat (θP_true, θP) # all parameters overestimated
127138
128- # n_covar = get_hybridproblem_n_covar(prob; scenario)
129- # , n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
130139
131- n_covar = size (xM, 1 )
140+ () -> begin
141+ # n_covar = get_hybridproblem_n_covar(prob; scenario)
142+ # , n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
132143
133- # ----- fit g to θMs_true
134- g, ϕg0 = get_hybridproblem_MLapplicator (prob; scenario);
135- (; transP, transM) = get_hybridproblem_transforms (prob; scenario)
144+ n_covar = size (xM, 1 )
136145
137- function loss_g (ϕg, x, g, transM)
138- ζMs = g (x, ϕg) # predict the log of the parameters
139- θMs = reduce (hcat, map (transM, eachcol (ζMs))) # transform each column
140- loss = sum (abs2, θMs .- θMs_true)
141- return loss, θMs
142- end
143- loss_g (ϕg0, xM, g, transM)
146+ # ----- fit g to θMs_true
147+ g, ϕg0 = get_hybridproblem_MLapplicator (prob; scenario);
148+ (; transP, transM) = get_hybridproblem_transforms (prob; scenario)
144149
145- optf = Optimization. OptimizationFunction ((ϕg, p) -> loss_g (ϕg, xM, g, transM)[1 ],
146- Optimization. AutoZygote ())
147- optprob = Optimization. OptimizationProblem (optf, ϕg0);
148- res = Optimization. solve (optprob, Adam (0.02 ), callback = callback_loss (100 ), maxiters = 800 );
150+ function loss_g (ϕg, x, g, transM)
151+ ζMs = g (x, ϕg) # predict the log of the parameters
152+ θMs = reduce (hcat, map (transM, eachcol (ζMs))) # transform each column
153+ loss = sum (abs2, θMs .- θMs_true)
154+ return loss, θMs
155+ end
156+ loss_g (ϕg0, xM, g, transM)
157+
158+ optf = Optimization. OptimizationFunction ((ϕg, p) -> loss_g (ϕg, xM, g, transM)[1 ],
159+ Optimization. AutoZygote ())
160+ optprob = Optimization. OptimizationProblem (optf, ϕg0);
161+ res = Optimization. solve (optprob, Adam (0.02 ), callback = callback_loss (100 ), maxiters = 800 );
149162
150- ϕg_opt1 = res. u;
151- l1, θMs_pred = loss_g (ϕg_opt1, xM, g, transM)
152- scatterplot (vec (θMs_true), vec (θMs_pred))
163+ ϕg_opt1 = res. u;
164+ l1, θMs_pred = loss_g (ϕg_opt1, xM, g, transM)
165+ scatterplot (vec (θMs_true), vec (θMs_pred))
153166
154- f = get_hybridproblem_PBmodel (prob; scenario)
155- py = get_hybridproblem_neg_logden_obs (prob; scenario)
167+ f = get_hybridproblem_PBmodel (prob; scenario)
168+ py = get_hybridproblem_neg_logden_obs (prob; scenario)
156169
157- # ----------- fit g and θP to y_o
158- () -> begin
159- # end2end inversion
170+ # ----------- fit g and θP to y_o
171+ () -> begin
172+ # end2end inversion
160173
161- int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
162- ϕg = 1 : length (ϕg0), θP = par_templates. θP))
163- p = p0 = vcat (ϕg0, par_templates. θP .* 0.9 ) # slightly disturb θP_true
174+ int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
175+ ϕg = 1 : length (ϕg0), θP = par_templates. θP))
176+ p = p0 = vcat (ϕg0, par_templates. θP .* 0.9 ) # slightly disturb θP_true
164177
165- # Pass the site-data for the batches as separate vectors wrapped in a tuple
166- train_loader = MLUtils. DataLoader ((xM, xP, y_o, y_unc), batchsize = n_batch)
178+ # Pass the site-data for the batches as separate vectors wrapped in a tuple
179+ train_loader = MLUtils. DataLoader ((xM, xP, y_o, y_unc), batchsize = n_batch)
167180
168- loss_gf = get_loss_gf (g, f, y_global_o, int_ϕθP)
169- l1 = loss_gf (p0, train_loader. data... )[1 ]
181+ loss_gf = get_loss_gf (g, f, y_global_o, int_ϕθP)
182+ l1 = loss_gf (p0, train_loader. data... )[1 ]
170183
171- optf = Optimization. OptimizationFunction ((ϕ, data) -> loss_gf (ϕ, data... )[1 ],
172- Optimization. AutoZygote ())
173- optprob = OptimizationProblem (optf, p0, train_loader)
184+ optf = Optimization. OptimizationFunction ((ϕ, data) -> loss_gf (ϕ, data... )[1 ],
185+ Optimization. AutoZygote ())
186+ optprob = OptimizationProblem (optf, p0, train_loader)
174187
175- res = Optimization. solve (
176- optprob, Adam (0.02 ), callback = callback_loss (100 ), maxiters = 1000 )
188+ res = Optimization. solve (
189+ optprob, Adam (0.02 ), callback = callback_loss (100 ), maxiters = 1000 )
177190
178- l1, y_pred_global, y_pred, θMs = loss_gf (res. u, train_loader. data... )
179- scatterplot (vec (θMs_true), vec (θMs))
180- scatterplot (log .(vec (θMs_true)), log .(vec (θMs)))
181- scatterplot (vec (y_pred), vec (y_o))
182- hcat (par_templates. θP, int_ϕθP (res. u). θP)
191+ l1, y_pred_global, y_pred, θMs = loss_gf (res. u, train_loader. data... )
192+ scatterplot (vec (θMs_true), vec (θMs))
193+ scatterplot (log .(vec (θMs_true)), log .(vec (θMs)))
194+ scatterplot (vec (y_pred), vec (y_o))
195+ hcat (par_templates. θP, int_ϕθP (res. u). θP)
196+ end
183197end
184198
185199# ---------- HVI
0 commit comments