@@ -7,7 +7,7 @@ using Statistics
 using ComponentArrays: ComponentArrays as CA
 
 using SimpleChains
-import Flux # to allow for FluxMLEngine and cpu()
+import Flux
 using MLUtils
 import Zygote
 
@@ -17,41 +17,43 @@ using Bijectors
 using UnicodePlots
 
 const case = DoubleMM.DoubleMMCase()
-const MLengine = Val(nameof(SimpleChains))
-const FluxMLengine = Val(nameof(Flux))
 scenario = (:default,)
 rng = StableRNG(111)
 
 par_templates = get_hybridcase_par_templates(case; scenario)
 
-(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
+# n_covar = get_hybridcase_n_covar(case; scenario)
+# , n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
+
+(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
+) = gen_hybridcase_synthetic(rng, case; scenario);
+
+n_covar = size(xM, 1)
 
-(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
-) = gen_hybridcase_synthetic(case, rng; scenario);
 
 # ----- fit g to θMs_true
-g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
+g, ϕg0 = get_hybridcase_MLapplicator(case; scenario);
+(; transP, transM) = get_hybridcase_transforms(case; scenario)
 
-function loss_g(ϕg, x, g)
+function loss_g(ϕg, x, g, transM)
     ζMs = g(x, ϕg) # predict the log of the parameters
-    θMs = exp.(ζMs)
+    θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
     loss = sum(abs2, θMs .- θMs_true)
     return loss, θMs
 end
-loss_g(ϕg0, xM, g)
-Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
+loss_g(ϕg0, xM, g, transM)
 
-optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
+optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
     Optimization.AutoZygote())
 optprob = Optimization.OptimizationProblem(optf, ϕg0);
 res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
 
 ϕg_opt1 = res.u;
-loss_g(ϕg_opt1, xM, g)
-scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
-@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
+l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
+scatterplot(vec(θMs_true), vec(θMs_pred))
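Editor's sketch (illustrative, not part of the commit): the reworked loss_g applies transM column-wise to map the unconstrained ML outputs ζMs onto the parameter scale instead of hard-coding exp. Assuming transM is the Bijectors.jl transform the removed code below constructs, Stacked(elementwise(identity), elementwise(exp)), a minimal standalone check of that column-wise mapping:

    using Bijectors
    transM_demo = Stacked(elementwise(identity), elementwise(exp))  # assumed default transform, taken from the removed lines
    ζMs_demo = [0.0 1.0; 0.0 1.0]                                   # unconstrained outputs, one column per site
    θMs_demo = reduce(hcat, map(transM_demo, eachcol(ζMs_demo)))    # row 1 unchanged, row 2 exponentiated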
 
 f = get_hybridcase_PBmodel(case; scenario)
+py = get_hybridcase_neg_logden_obs(case; scenario)
 
 # ----------- fit g and θP to y_o
 () -> begin
@@ -62,7 +64,7 @@ f = get_hybridcase_PBmodel(case; scenario)
     p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true
 
     # Pass the site-data for the batches as separate vectors wrapped in a tuple
-    train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
+    train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
 
     loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
     l1 = loss_gf(p0, train_loader.data...)[1]
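Editor's sketch (illustrative, not part of the commit): MLUtils.DataLoader batches a tuple of arrays along their last (site) dimension, so adding y_unc simply makes each batch a 4-tuple of column blocks for the same sites. With made-up sizes:

    using MLUtils
    xM_d, xP_d = rand(3, 10), rand(2, 10)        # hypothetical covariates for 10 sites
    y_o_d, y_unc_d = rand(4, 10), rand(4, 10)    # hypothetical observations and their log-variances
    dl = MLUtils.DataLoader((xM_d, xP_d, y_o_d, y_unc_d), batchsize = 5)
    xM_b, xP_b, y_o_b, y_unc_b = first(dl)       # each block holds the same 5 sites (columns)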
@@ -82,15 +84,16 @@ f = get_hybridcase_PBmodel(case; scenario)
 end
 
 # ---------- HVI
-logσ2y = 2 .* log.(σ_o)
 n_MC = 3
-transP = elementwise(exp)
-transM = Stacked(elementwise(identity), elementwise(exp))
+(; transP, transM) = get_hybridcase_transforms(case; scenario)
+FT = get_hybridcase_float_type(case; scenario)
 
 (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
-    θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
+    θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
 ϕ_true = ϕ
 
+
+
 () -> begin
     coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
     logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
@@ -149,49 +152,22 @@ transM = Stacked(elementwise(identity), elementwise(exp))
     ϕ_true = inverse_ca(trans_gu, ϕt_true)
 end
 
-ϕ_ini0 = ζ = vcat(ϕ_true[:μP] .* 0.0, ϕg0, ϕ_true[[:unc]]); # scratch
+ϕ_ini0 = ζ = reduce(
+    vcat, (
+        ϕ_true[[:μP]] .* FT(0.001), CA.ComponentVector(ϕg = ϕg0), ϕ_true[[:unc]])) # scratch
 #
-# true values
-ϕ_ini = ζ = vcat(ϕ_true[[:μP, :ϕg]] .* 1.2, ϕ_true[[:unc]]); # slight disturbance
+ϕ_ini = ζ = reduce(
+    vcat, (
+        ϕ_true[[:μP]] .- FT(0.1), ϕ_true[[:ϕg]] .* FT(1.1), ϕ_true[[:unc]])) # slight disturbance
 # hardcoded from HMC inversion
 ϕ_ini.unc.coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
 ϕ_ini.unc.logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
 mean_σ_o_MC = 0.006042
 
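Editor's sketch (illustrative, not part of the commit): the change from vcat(ϕ_true[:μP] .* 0.0, ϕg0, ...) to reduce(vcat, (ϕ_true[[:μP]] ..., ...)) matters because indexing a ComponentVector with [:μP] drops the component label while [[:μP]] keeps it, so the concatenation yields a labelled ComponentVector again:

    using ComponentArrays: ComponentArrays as CA
    cv = CA.ComponentVector(μP = [0.1, 0.2], unc = [0.3])
    cv[:μP]      # plain values, the μP label is gone
    cv[[:μP]]    # still a ComponentVector carrying the μP label
    ϕ_demo = reduce(vcat, (cv[[:μP]] .* 0.001, CA.ComponentVector(ϕg = zeros(2)), cv[[:unc]]))
    ϕ_demo.ϕg    # components stay addressable by name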
-# test cost function and gradient
-() -> begin
-    neg_elbo_transnorm_gf(rng, g, f, ϕ_true, y_o[:, 1:n_batch], xM[:, 1:n_batch],
-        transPMs_batch, map(get_concrete, interpreters);
-        n_MC = 8, logσ2y)
-    Zygote.gradient(
-        ϕ -> neg_elbo_transnorm_gf(
-            rng, g, f, ϕ, y_o[:, 1:n_batch], xM[:, 1:n_batch],
-            transPMs_batch, interpreters; n_MC = 8, logσ2y),
-        CA.getdata(ϕ_true))
-end
-
-# optimize using SimpleChains
-() -> begin
-    train_loader = MLUtils.DataLoader((xM, y_o), batchsize = n_batch)
-
-    optf = Optimization.OptimizationFunction(
-        (ϕ, data) -> begin
-            xM, y_o = data
-            neg_elbo_transnorm_gf(
-                rng, g, f, ϕ, y_o, xM, transPMs_batch,
-                map(get_concrete, interpreters_g); n_MC = 5, logσ2y)
-        end,
-        Optimization.AutoZygote())
-    optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini), train_loader)
-    res = Optimization.solve(
-        optprob, Optimisers.Adam(0.02), callback = callback_loss(50), maxiters = 800)
-    # optprob = Optimization.OptimizationProblem(optf, ϕ_ini0);
-    # res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
-end
-
-ϕ = ϕ_ini |> Flux.gpu;
+ϕ = CA.getdata(ϕ_ini) |> Flux.gpu;
 xM_gpu = xM |> Flux.gpu;
-g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
+scenario_flux = (scenario..., :use_Flux)
+g_flux, _ = get_hybridcase_MLapplicator(case; scenario = scenario_flux);
 
 # optimize using LUX
 () -> begin
@@ -216,27 +192,25 @@ g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario
     g_flux = g_luxs
 end
 
-function fcost(ϕ, xM, y_o)
-    neg_elbo_transnorm_gf(rng, g_flux, f, CA.getdata(ϕ), y_o,
-        xM, transPMs_batch, map(get_concrete, interpreters);
-        n_MC = 8, logσ2y = logσ2y)
+function fcost(ϕ, xM, y_o, y_unc)
+    neg_elbo_transnorm_gf(rng, CA.getdata(ϕ), g_flux, transPMs_batch, f, py,
+        xM, xP, y_o, y_unc, map(get_concrete, interpreters);
+        n_MC = 8)
 end
-fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch])
+fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch])
 # Zygote.gradient(fcost, ϕ) |> cpu;
 gr = Zygote.gradient(fcost,
-    CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch]));
-gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
+    CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]),
+    CA.getdata(y_o[:, 1:n_batch]), CA.getdata(y_unc[:, 1:n_batch]));
+gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...)
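Editor's sketch (illustrative, not part of the commit): ϕ is now the raw data of ϕ_ini moved to the GPU, so the gradient comes back as a flat array; CA.getaxes(ϕ_ini) re-attaches the component labels, which is why gr_c is built from ϕ_ini rather than ϕ. A CPU-only illustration:

    using ComponentArrays: ComponentArrays as CA
    ϕ_demo = CA.ComponentVector(μP = [0.1, 0.2], ϕg = zeros(3), unc = [0.0])
    flat = copy(CA.getdata(ϕ_demo))                      # stand-in for a gradient returned as a plain vector
    flat_c = CA.ComponentArray(flat, CA.getaxes(ϕ_demo)...)
    flat_c.μP                                            # named access works again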
 
-train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
-train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
+train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
+# train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_Flux))
 
 optf = Optimization.OptimizationFunction(
     (ϕ, data) -> begin
-        xM, y_o = data
-        fcost(ϕ, xM, y_o)
-        # neg_elbo_transnorm_gf(
-        #     rng, g_flux, f, ϕ, y_o, xM, transPMs_batch,
-        #     map(get_concrete, interpreters); n_MC = 5, logσ2y)
+        xM, xP, y_o, y_unc = data
+        fcost(ϕ, xM, y_o, y_unc)
     end,
     Optimization.AutoZygote())
 optprob = Optimization.OptimizationProblem(
@@ -256,7 +230,7 @@
 ζMs_VI = g_flux(xM_gpu, ζ_VIc.ϕg |> Flux.gpu) |> Flux.cpu
 ϕunc_VI = interpreters.unc(ζ_VIc.unc)
 
-hcat(θP_true, exp.(ζ_VIc.μP))
+hcat(log.(θP_true), ϕ_ini.μP, ζ_VIc.μP)
 plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI)))
 # lineplot!(plt, 0.0, 1.1, identity)
 #
@@ -266,11 +240,12 @@ hcat(ϕ_ini.unc, ϕunc_VI) # need to compare to MC sample
 # test predicting correct obs-uncertainty of predictive posterior
 n_sample_pred = 200
 
-y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, interpreters;
+y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
     get_transPMs, get_ca_int_PMs, n_sample_pred);
 size(y_pred) # n_obs x n_site, n_sample_pred
 
 σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
+σ_o = exp.(y_unc[:, 1] / 2)
 
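Editor's sketch (illustrative, not part of the commit): y_unc stores the log-variance of each observation, replacing the former σ_o / logσ2y pair (the removed line read logσ2y = 2 .* log.(σ_o)), so the observation σ is recovered as exp.(y_unc ./ 2):

    σ_demo = 0.01
    y_unc_demo = 2 * log(σ_demo)      # log-variance, matching the removed logσ2y definition
    exp(y_unc_demo / 2) ≈ σ_demo      # true: σ recovered from the stored log-variance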
 # describe(σ_o_post)
 hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),
@@ -282,7 +257,7 @@ histogram(vec(mean_y_pred - y_true)) # predictions centered around y_o (or y_tru
 
 # look at θP, θM1 of first site
 intm_PMs_gen = get_ca_int_PMs(n_site)
-ζs, _σ = HVI.generate_ζ(rng, g_flux, f, res.u, xM_gpu,
+ζs, _σ = HVI.generate_ζ(rng, g_flux, res.u, xM_gpu,
     (; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred);
 ζs = ζs |> Flux.cpu;
 θPM = vcat(θP_true, θMs_true[:, 1])