@@ -26,32 +26,32 @@ par_templates = get_hybridcase_par_templates(case; scenario)

(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)

-(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
+(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
) = gen_hybridcase_synthetic(case, rng; scenario);

# ----- fit g to θMs_true
g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
+(; transP, transM) = get_hybridcase_transforms(case; scenario)

-function loss_g(ϕg, x, g)
+function loss_g(ϕg, x, g, transM)
    ζMs = g(x, ϕg) # predict the log of the parameters
-    θMs = exp.(ζMs)
+    θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
    loss = sum(abs2, θMs .- θMs_true)
    return loss, θMs
end
-loss_g(ϕg0, xM, g)
-Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
+loss_g(ϕg0, xM, g, transM)
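# Editorial sketch, not part of the diff: the reduce(hcat, map(transM, eachcol(ζMs)))
# pattern above applies a vector-valued transform to each column and re-assembles the
# matrix. A minimal stand-alone example with an elementwise stand-in for transM:
ζ_toy = [0.0 1.0; -1.0 2.0]                        # toy matrix of unconstrained values
transM_toy = ζ -> exp.(ζ)                          # hypothetical column transform
θ_toy = reduce(hcat, map(transM_toy, eachcol(ζ_toy)))
θ_toy == exp.(ζ_toy)                               # true for a purely elementwise transform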

-optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
+optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
    Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, ϕg0);
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
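# Editorial sketch, not part of the diff: callback_loss is a project helper that is not
# shown here. Assuming it reports the loss every `nshow` iterations, a minimal version
# compatible with recent Optimization.jl callbacks (signature (state, loss) -> Bool,
# where returning true stops the solve) could look like:
callback_loss_sketch = nshow -> begin
    counter = Ref(0)
    (state, loss) -> begin
        counter[] += 1
        counter[] % nshow == 0 && @info "iteration $(counter[]): loss = $loss"
        return false  # keep iterating
    end
end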

ϕg_opt1 = res.u;
-loss_g(ϕg_opt1, xM, g)
-scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
-@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
+l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
+scatterplot(vec(θMs_true), vec(θMs_pred))

f = get_hybridcase_PBmodel(case; scenario)
+py = get_hybridcase_neg_logden_obs(case; scenario)
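# Editorial note, not part of the diff: py appears to replace the fixed logσ2y used
# previously; together with y_unc it lets the ELBO below evaluate a per-observation
# negative log-density instead of a single global observation uncertainty.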

# ----------- fit g and θP to y_o
() -> begin
@@ -82,13 +82,12 @@ f = get_hybridcase_PBmodel(case; scenario)
end

# ---------- HVI
-logσ2y = 2 .* log.(σ_o)
n_MC = 3
-transP = elementwise(exp)
-transM = Stacked(elementwise(identity), elementwise(exp))
+(; transP, transM) = get_hybridcase_transforms(case; scenario)
+FT = get_hybridcase_float_type(case; scenario)

(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
-    θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
+    θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
ϕ_true = ϕ

() -> begin
@@ -149,49 +148,21 @@ transM = Stacked(elementwise(identity), elementwise(exp))
    ϕ_true = inverse_ca(trans_gu, ϕt_true)
end

-ϕ_ini0 = ζ = vcat(ϕ_true[:μP] .* 0.0, ϕg0, ϕ_true[[:unc]]); # scratch
+ϕ_ini0 = ζ = reduce(
+    vcat, (
+        ϕ_true[[:μP]] .* FT(0.001), CA.ComponentVector(ϕg = ϕg0), ϕ_true[[:unc]])) # scratch
#
-# true values
-ϕ_ini = ζ = vcat(ϕ_true[[:μP, :ϕg]] .* 1.2, ϕ_true[[:unc]]); # slight disturbance
+ϕ_ini = ζ = reduce(
+    vcat, (
+        ϕ_true[[:μP]] .- FT(0.1), ϕ_true[[:ϕg]] .* FT(1.1), ϕ_true[[:unc]])) # slight disturbance
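# Editorial sketch, not part of the diff: indexing a ComponentVector with a vector of
# symbols (e.g. ϕ_true[[:μP]]) keeps the labeled axis, whereas ϕ_true[:μP] returns only
# the underlying values; keeping the labels is what lets reduce(vcat, ...) above rebuild
# a fully labeled parameter vector. A minimal stand-alone example:
import ComponentArrays as CA
cv = CA.ComponentVector(μP = [1.0, 2.0], unc = [0.5])
cv[[:μP]]   # ComponentVector holding the μP block, labels preserved
cv[:μP]     # just the values of the μP block
reduce(vcat, (cv[[:μP]] .* 0.001, CA.ComponentVector(ϕg = zeros(3)), cv[[:unc]]))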
# hardcoded from HMC inversion
ϕ_ini.unc.coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
ϕ_ini.unc.logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
mean_σ_o_MC = 0.006042

-# test cost function and gradient
-() -> begin
-    neg_elbo_transnorm_gf(rng, g, f, ϕ_true, y_o[:, 1:n_batch], xM[:, 1:n_batch],
-        transPMs_batch, map(get_concrete, interpreters);
-        n_MC = 8, logσ2y)
-    Zygote.gradient(
-        ϕ -> neg_elbo_transnorm_gf(
-            rng, g, f, ϕ, y_o[:, 1:n_batch], xM[:, 1:n_batch],
-            transPMs_batch, interpreters; n_MC = 8, logσ2y),
-        CA.getdata(ϕ_true))
-end
-
-# optimize using SimpleChains
-() -> begin
-    train_loader = MLUtils.DataLoader((xM, y_o), batchsize = n_batch)
-
-    optf = Optimization.OptimizationFunction(
-        (ϕ, data) -> begin
-            xM, y_o = data
-            neg_elbo_transnorm_gf(
-                rng, g, f, ϕ, y_o, xM, transPMs_batch,
-                map(get_concrete, interpreters_g); n_MC = 5, logσ2y)
-        end,
-        Optimization.AutoZygote())
-    optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini), train_loader)
-    res = Optimization.solve(
-        optprob, Optimisers.Adam(0.02), callback = callback_loss(50), maxiters = 800)
-    # optprob = Optimization.OptimizationProblem(optf, ϕ_ini0);
-    # res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
-end
-
-ϕ = ϕ_ini |> Flux.gpu;
+ϕ = CA.getdata(ϕ_ini) |> Flux.gpu;
xM_gpu = xM |> Flux.gpu;
-g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
+g_flux, _ = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);

# optimize using Lux
() -> begin
@@ -216,27 +187,25 @@ g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario
    g_flux = g_luxs
end

-function fcost(ϕ, xM, y_o)
-    neg_elbo_transnorm_gf(rng, g_flux, f, CA.getdata(ϕ), y_o,
-        xM, transPMs_batch, map(get_concrete, interpreters);
-        n_MC = 8, logσ2y = logσ2y)
+function fcost(ϕ, xM, y_o, y_unc)
+    neg_elbo_transnorm_gf(rng, g_flux, f, py, CA.getdata(ϕ), y_o, y_unc,
+        xM, xP, transPMs_batch, map(get_concrete, interpreters);
+        n_MC = 8)
end
-fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch])
+fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch])
# Zygote.gradient(fcost, ϕ) |> cpu;
gr = Zygote.gradient(fcost,
-    CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch]));
-gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
+    CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]),
+    CA.getdata(y_o[:, 1:n_batch]), CA.getdata(y_unc[:, 1:n_batch]));
+gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...)
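# Editorial sketch, not part of the diff: CA.getaxes recovers the labeled axes of a
# ComponentVector, so a flat gradient pulled back from the GPU can be re-labeled and
# inspected block-wise (block names follow ϕ_ini defined above):
flat_grad = collect(CA.getdata(ϕ_ini))                   # stand-in for gr[1] |> Flux.cpu
relabeled = CA.ComponentArray(flat_grad, CA.getaxes(ϕ_ini)...)
relabeled.unc                                             # e.g. the gradient block for :unc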

-train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
-train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
+train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
+# train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
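# Editorial sketch, not part of the diff: the DataLoader slices all four arrays along
# their last (site) dimension, so every element is a tuple of aligned per-batch arrays
# that the (ϕ, data) objective below destructures as xM, xP, y_o, y_unc:
xM_b, xP_b, y_o_b, y_unc_b = first(train_loader)
size(xM_b, 2) == size(y_o_b, 2) == n_batch               # batches stay aligned across arrays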

optf = Optimization.OptimizationFunction(
    (ϕ, data) -> begin
-        xM, y_o = data
-        fcost(ϕ, xM, y_o)
-        # neg_elbo_transnorm_gf(
-        #     rng, g_flux, f, ϕ, y_o, xM, transPMs_batch,
-        #     map(get_concrete, interpreters); n_MC = 5, logσ2y)
+        xM, xP, y_o, y_unc = data
+        fcost(ϕ, xM, y_o, y_unc)
    end,
    Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(
ζMs_VI = g_flux(xM_gpu, ζ_VIc.ϕg |> Flux.gpu) |> Flux.cpu
ϕunc_VI = interpreters.unc(ζ_VIc.unc)

-hcat(θP_true, exp.(ζ_VIc.μP))
+hcat(log.(θP_true), ϕ_ini.μP, ζ_VIc.μP)
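# Editorial note, not part of the diff: the comparison now happens in unconstrained
# (log) space, lining up the true values, the disturbed initial values, and the fitted
# means of the global parameters side by side.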
plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI)))
# lineplot!(plt, 0.0, 1.1, identity)
#
@@ -266,11 +235,12 @@ hcat(ϕ_ini.unc, ϕunc_VI) # need to compare to MC sample
# test predicting correct obs-uncertainty of predictive posterior
n_sample_pred = 200

-y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, interpreters;
+y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
    get_transPMs, get_ca_int_PMs, n_sample_pred);
size(y_pred) # n_obs x n_site, n_sample_pred

σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
+σ_o = exp.(y_unc[:, 1] / 2)
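# Editorial note, not part of the diff: consistent with the removed
# logσ2y = 2 .* log.(σ_o), y_unc holds log-variances of the observations, so
# exp.(y_unc ./ 2) recovers the standard deviations compared against σ_o_post below.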

# describe(σ_o_post)
hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),