@@ -5,82 +5,195 @@ using StableRNGs
 using Random
 using Statistics
 using ComponentArrays: ComponentArrays as CA
-
+using Optimization
+using OptimizationOptimisers # Adam
+using UnicodePlots
 using SimpleChains
-import Flux
+using Flux
 using MLUtils
-import Zygote
-
 using CUDA
-using OptimizationOptimisers
-using Bijectors
-using UnicodePlots
-
-const prob = DoubleMM.DoubleMMCase()
-scenario = (:default,)
-rng = StableRNG(111)
-
-par_templates = get_hybridproblem_par_templates(prob; scenario)
 
-# n_covar = get_hybridproblem_n_covar(prob; scenario)
-# , n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
+rng = StableRNG(114)
+scenario = NTuple{0, Symbol}()
+scenario = (:use_Flux,)
 
+# ------ setup synthetic data and training data loader
 (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
-) = gen_hybridcase_synthetic(rng, prob; scenario);
-
-n_covar = size(xM, 1)
+) = gen_hybridcase_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
+xM_cpu = xM
+if :use_Flux ∈ scenario
+    xM = CuArray(xM_cpu)
+end
+get_train_loader = (rng; n_batch, kwargs...) -> MLUtils.DataLoader((xM, xP, y_o, y_unc);
+    batchsize = n_batch, partial = false)
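+# note: this loader closure ignores the rng and extra kwargs it receives and always
+# batches the full synthetic site data; partial = false drops an incomplete last batch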
+σ_o = exp(first(y_unc) / 2)
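+# (y_unc apparently holds log observation variances, hence σ_o = exp(logσ² / 2))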
+
+# assign the train_loader; otherwise each call would create another version of the synthetic data
+prob0 = HVI.update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)
+
+# ------- pointwise hybrid model fit
+solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
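+# (HybridPointSolver presumably fits the ML weights ϕg and the global parameters θP
+# jointly as point estimates via minibatched gradient descent, without uncertainty)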
+# solver = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
+# solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
+(; ϕ, resopt) = solve(prob0, solver; scenario,
+    rng, callback = callback_loss(100), maxiters = 1200);
+# update the problem with optimized parameters
+prob0o = HVI.update(prob0; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP)
+y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario);
+scatterplot(θMs_true[1, :], θMs[1, :])
+scatterplot(θMs_true[2, :], θMs[2, :])
+
+# do a few steps without minibatching,
+# by providing the data rather than the DataLoader
+solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
+(; ϕ, resopt) = solve(prob0o, solver1; scenario, rng,
+    callback = callback_loss(20), maxiters = 600);
+prob1o = HVI.update(prob0o; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP);
+y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario);
+scatterplot(θMs_true[1, :], θMs[1, :])
+scatterplot(θMs_true[2, :], θMs[2, :])
+prob1o.θP
+scatterplot(vec(y_true), vec(y_pred))
+
+# still overestimating θMs
+
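+# note: the `() -> begin ... end` blocks below are anonymous functions that are defined
+# but never called; they keep optional experiments in the script without running them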
+() -> begin # with more iterations?
+    prob2 = prob1o
+    (; ϕ, resopt) = solve(prob2, solver1; scenario, rng,
+        callback = callback_loss(20), maxiters = 600);
+    prob2o = update(prob2; ϕg = ϕ.ϕg, θP = ϕ.θP)
+    y_pred_global, y_pred, θMs = gf(prob2o, xM, xP);
+    prob2o.θP
+end
 
 
-# ----- fit g to θMs_true
-g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
-(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
+# ----------- fit g to true θMs
+() -> begin
+    # and fit gf starting from true parameters
+    prob = prob0
+    g, ϕg0_cpu = get_hybridproblem_MLapplicator(prob; scenario);
+    ϕg0 = (:use_Flux ∈ scenario) ? CuArray(ϕg0_cpu) : ϕg0_cpu
+    (; transP, transM) = get_hybridproblem_transforms(prob; scenario)
+
+    function loss_g(ϕg, x, g, transM; gpu_handler = HVI.default_GPU_DataHandler)
+        ζMs = g(x, ϕg) # predict the log of the parameters
+        ζMs_cpu = gpu_handler(ζMs)
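+        # (the GPU data handler presumably moves ζMs back to the CPU so that the
+        # column-wise transform below runs on a plain Array)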
+        θMs = reduce(hcat, map(transM, eachcol(ζMs_cpu))) # transform each column
+        loss = sum(abs2, θMs .- θMs_true)
+        return loss, θMs
+    end
+    loss_g(ϕg0, xM, g, transM)
+
+    optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
+        Optimization.AutoZygote())
+    optprob = Optimization.OptimizationProblem(optf, ϕg0);
+    res = Optimization.solve(optprob, Adam(0.015), callback = callback_loss(100), maxiters = 2000);
+
+    ϕg_opt1 = res.u;
+    l1, θMs = loss_g(ϕg_opt1, xM, g, transM)
+    # scatterplot(θMs_true[1, :], θMs[1, :])
+    scatterplot(θMs_true[2, :], θMs[2, :]) # able to fit θMs[2, :]
+
+    prob3 = HVI.update(prob0, ϕg = Array(ϕg_opt1), θP = θP_true)
+    solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
+    (; ϕ, resopt) = solve(prob3, solver1; scenario, rng,
+        callback = callback_loss(50), maxiters = 600);
+    prob3o = HVI.update(prob3; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP)
+    y_pred_global, y_pred, θMs = gf(prob3o, xM, xP; scenario);
+    scatterplot(θMs_true[2, :], θMs[2, :])
+    prob3o.θP
+    scatterplot(vec(y_true), vec(y_pred))
+    scatterplot(vec(y_true), vec(y_o))
+    scatterplot(vec(y_pred), vec(y_o))
 
-function loss_g(ϕg, x, g, transM)
-    ζMs = g(x, ϕg) # predict the log of the parameters
-    θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
-    loss = sum(abs2, θMs .- θMs_true)
-    return loss, θMs
+    () -> begin # optimized loss is indeed lower than with true parameters
+        int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
+            ϕg = 1:length(prob0.ϕg), θP = prob0.θP))
+        loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.f, Float32[], int_ϕθP)
+        loss_gf(vcat(prob3.ϕg, prob3.θP), xM, xP, y_o, y_unc)[1]
+        loss_gf(vcat(prob3o.ϕg, prob3o.θP), xM, xP, y_o, y_unc)[1]
+        #
+        loss_gf(vcat(prob2o.ϕg, prob2o.θP), xM, xP, y_o, y_unc)[1]
+    end
 end
-loss_g(ϕg0, xM, g, transM)
+
+# ----------- Hybrid Variational Inference: HVI
 
-optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
-    Optimization.AutoZygote())
-optprob = Optimization.OptimizationProblem(optf, ϕg0);
-res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
+using MLUtils
+import Zygote
+
+using CUDA
+using Bijectors
 
-ϕg_opt1 = res.u;
-l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
-scatterplot(vec(θMs_true), vec(θMs_pred))
+solver = HybridPosteriorSolver(; alg = Adam(0.01), n_batch = 60, n_MC = 3)
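+# (n_MC presumably sets the number of Monte Carlo draws per batch used to approximate
+# the expectation in the variational objective)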
+# solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
+(; ϕ, θP, resopt) = solve(prob0o, solver; scenario,
+    rng, callback = callback_loss(100), maxiters = 800);
+# update the problem with optimized parameters
+prob1o = HVI.update(prob0o; ϕg = cpu_ca(ϕ).ϕg, θP = θP)
+y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario);
+scatterplot(θMs_true[1, :], θMs[1, :])
+scatterplot(θMs_true[2, :], θMs[2, :])
+hcat(θP_true, θP) # all parameters overestimated
 
-f = get_hybridproblem_PBmodel(prob; scenario)
-py = get_hybridproblem_neg_logden_obs(prob; scenario)
 
-# ----------- fit g and θP to y_o
 () -> begin
-    # end2end inversion
+    # n_covar = get_hybridproblem_n_covar(prob; scenario)
+    # , n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
 
-    int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
-        ϕg = 1:length(ϕg0), θP = par_templates.θP))
-    p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true
+    n_covar = size(xM, 1)
 
-    # Pass the site-data for the batches as separate vectors wrapped in a tuple
-    train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
+    # ----- fit g to θMs_true
+    g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
+    (; transP, transM) = get_hybridproblem_transforms(prob; scenario)
 
-    loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
-    l1 = loss_gf(p0, train_loader.data...)[1]
+    function loss_g(ϕg, x, g, transM)
+        ζMs = g(x, ϕg) # predict the log of the parameters
+        θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
+        loss = sum(abs2, θMs .- θMs_true)
+        return loss, θMs
+    end
+    loss_g(ϕg0, xM, g, transM)
 
-    optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
+    optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
         Optimization.AutoZygote())
-    optprob = OptimizationProblem(optf, p0, train_loader)
+    optprob = Optimization.OptimizationProblem(optf, ϕg0);
+    res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
 
-    res = Optimization.solve(
-        optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000)
+    ϕg_opt1 = res.u;
+    l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
+    scatterplot(vec(θMs_true), vec(θMs_pred))
 
-    l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
-    scatterplot(vec(θMs_true), vec(θMs))
-    scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
-    scatterplot(vec(y_pred), vec(y_o))
-    hcat(par_templates.θP, int_ϕθP(res.u).θP)
+    f = get_hybridproblem_PBmodel(prob; scenario)
+    py = get_hybridproblem_neg_logden_obs(prob; scenario)
+
+    # ----------- fit g and θP to y_o
+    () -> begin
+        # end2end inversion
+
+        int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
+            ϕg = 1:length(ϕg0), θP = par_templates.θP))
+        p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true
+
+        # Pass the site-data for the batches as separate vectors wrapped in a tuple
+        train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
+
+        loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
+        l1 = loss_gf(p0, train_loader.data...)[1]
+
+        optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
+            Optimization.AutoZygote())
+        optprob = OptimizationProblem(optf, p0, train_loader)
+
+        res = Optimization.solve(
+            optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000)
+
+        l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
+        scatterplot(vec(θMs_true), vec(θMs))
+        scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
+        scatterplot(vec(y_pred), vec(y_o))
+        hcat(par_templates.θP, int_ϕθP(res.u).θP)
+    end
 end
 
 # ---------- HVI
@@ -92,8 +205,6 @@ FT = get_hybridproblem_float_type(prob; scenario)
     θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
 ϕ_true = ϕ
 
-
-
 () -> begin
     coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
     logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
@@ -245,7 +356,7 @@ y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
 size(y_pred) # n_obs x n_site, n_sample_pred
 
 σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
-σ_o = exp.(y_unc[:,1] / 2)
+σ_o = exp.(y_unc[:, 1] / 2)
 
 # describe(σ_o_post)
 hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),