@@ -65,28 +65,245 @@ loss_g(ϕg_opt1, xM, g)
6565scatterplot (vec (θMs_true), vec (loss_g (ϕg_opt1, xM, g)[2 ]))
6666@test cor (vec (θMs_true), vec (loss_g (ϕg_opt1, xM, g)[2 ])) > 0.9
6767
68- # ----------- fit g and θP to y_o
69- f = gen_hybridcase_PBmodel (case; scenario)
68+ tmpf = () -> begin
69+ # ----------- fit g and θP to y_o
70+ # end2end inversion
71+ f = gen_hybridcase_PBmodel (case; scenario)
7072
71- int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
72- ϕg = 1 : length (ϕg0), θP = par_templates. θP))
73- p = p0 = vcat (ϕg0, par_templates. θP .* 0.9 ); # slightly disturb θP_true
73+ int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
74+ ϕg = 1 : length (ϕg0), θP = par_templates. θP))
75+ p = p0 = vcat (ϕg0, par_templates. θP .* 0.9 ); # slightly disturb θP_true
7476
75- # Pass the site-data for the batches as separate vectors wrapped in a tuple
76- train_loader = MLUtils. DataLoader ((xM, xP, y_o), batchsize = n_batch)
77+ # Pass the site-data for the batches as separate vectors wrapped in a tuple
78+ train_loader = MLUtils. DataLoader ((xM, xP, y_o), batchsize = n_batch)
7779
78- loss_gf = get_loss_gf (g, f, y_global_o, int_ϕθP)
79- l1 = loss_gf (p0, train_loader. data... )[1 ]
80+ loss_gf = get_loss_gf (g, f, y_global_o, int_ϕθP)
81+ l1 = loss_gf (p0, train_loader. data... )[1 ]
8082
81- optf = Optimization. OptimizationFunction ((ϕ, data) -> loss_gf (ϕ, data... )[1 ],
83+ optf = Optimization. OptimizationFunction ((ϕ, data) -> loss_gf (ϕ, data... )[1 ],
84+ Optimization. AutoZygote ())
85+ optprob = OptimizationProblem (optf, p0, train_loader)
86+
87+ res = Optimization. solve (
88+ optprob, Adam (0.02 ), callback = callback_loss (100 ), maxiters = 1000 );
89+
90+ l1, y_pred_global, y_pred, θMs = loss_gf (res. u, train_loader. data... )
91+ scatterplot (vec (θMs_true), vec (θMs))
92+ scatterplot (log .(vec (θMs_true)), log .(vec (θMs)))
93+ scatterplot (vec (y_pred), vec (y_o))
94+ hcat (par_templates. θP, int_ϕθP (res. u). θP)
95+ end
96+
97+ # ---------- HADVI
98+ # TODO think about good general initializations
99+ coef_logσ2_logMs = [- 5.769 - 3.501 ; - 0.01791 0.007951 ]
100+ logσ2_logP = CA. ComponentVector (r0= - 8.997 , K2= - 5.893 )
101+ mean_σ_o_MC = 0.006042
102+
103+ # correlation matrices
104+ ρsP = zeros (sum (1 : (n_θP- 1 )))
105+ ρsM = zeros (sum (1 : (n_θM- 1 )))
106+
107+ ϕunc = CA. ComponentVector (;
108+ logσ2_logP= logσ2_logP,
109+ coef_logσ2_logMs= coef_logσ2_logMs,
110+ ρsP,
111+ ρsM)
112+ int_unc = ComponentArrayInterpreter (ϕunc)
113+
114+ # for a conservative uncertainty assume σ2=1e-10 and no relationship with magnitude
115+ ϕunc0 = CA. ComponentVector (;
116+ logσ2_logP= fill (- 10.0 , n_θP),
117+ coef_logσ2_logMs= reduce (hcat, ([- 10.0 , 0.0 ] for _ in 1 : n_θM)),
118+ ρsP,
119+ ρsM)
120+
121+ logσ2y = fill (2 * log (σ_o), size (y_o, 1 ))
122+ n_MC = 3
123+
124+
125+ # -------------- ADVI with g inside cost function
126+ using CUDA
127+ using TransformVariables
128+
129+ transPMs_batch = as (
130+ (P= as (Array, asℝ₊, n_θP),
131+ Ms= as (Array, asℝ₊, n_θM, n_batch)))
132+ transPMs_all = as (
133+ (P= as (Array, asℝ₊, n_θP),
134+ Ms= as (Array, asℝ₊, n_θM, n_site)))
135+
136+ ϕ_true = θ = CA. ComponentVector (;
137+ μP= θP_true,
138+ ϕg= ϕg_opt,
139+ unc= ϕunc);
140+ trans_gu = as (
141+ (μP= as (Array, asℝ₊, n_θP),
142+ ϕg= as (Array, n_ϕg),
143+ unc= as (Array, length (ϕunc))))
144+ trans_g = as (
145+ (μP= as (Array, asℝ₊, n_θP),
146+ ϕg= as (Array, n_ϕg)))
147+
148+ const int_PMs_batch = ComponentArrayInterpreter (CA. ComponentVector (; θP,
149+ θMs= CA. ComponentMatrix (
150+ zeros (n_θM, n_batch), first (CA. getaxes (θM)), CA. Axis (i= 1 : n_batch))))
151+
152+ interpreters = interpreters_g = map (get_concrete,(;
153+ μP_ϕg_unc= ComponentArrayInterpreter (ϕ_true),
154+ PMs= int_PMs_batch,
155+ unc= ComponentArrayInterpreter (ϕunc)
156+ ))
157+
158+ ϕg_true_vec = CA. ComponentVector (
159+ TransformVariables. inverse (trans_gu, cv2NamedTuple (ϕ_true)))
160+ ϕcg_true = interpreters. μP_ϕg_unc (ϕg_true_vec)
161+ ϕ_ini = ζ = vcat (ϕcg_true[[:μP , :ϕg ]] .* 1.2 , ϕcg_true[[:unc ]]);
162+ ϕ_ini0 = ζ = vcat (ϕcg_true[:μP ] .* 0.0 , SimpleChains. init_params (g), ϕunc0);
163+
164+ neg_elbo_transnorm_gf (rng, g, f, ϕcg_true, y_o[:, 1 : n_batch], x_o[:, 1 : n_batch],
165+ transPMs_batch, map (get_concrete, interpreters);
166+ n_MC= 8 , logσ2y)
167+ Zygote. gradient (ϕ -> neg_elbo_transnorm_gf (
168+ rng, g, f, ϕ, y_o[:, 1 : n_batch], x_o[:, 1 : n_batch],
169+ transPMs_batch, interpreters; n_MC= 8 , logσ2y), ϕcg_true)
170+
171+ () -> begin
172+ train_loader = MLUtils. DataLoader ((x_o, y_o), batchsize = n_batch)
173+
174+ optf = Optimization. OptimizationFunction ((ζg, data) -> begin
175+ x_o, y_o = data
176+ neg_elbo_transnorm_gf (
177+ rng, g, f, ζg, y_o, x_o, transPMs_batch, map (get_concrete, interpreters_g); n_MC= 5 , logσ2y)
178+ end ,
179+ Optimization. AutoZygote ())
180+ optprob = Optimization. OptimizationProblem (optf, CA. getdata (ϕ_ini), train_loader);
181+ res = Optimization. solve (optprob, Optimisers. Adam (0.02 ), callback= callback_loss (50 ), maxiters= 800 );
182+ # optprob = Optimization.OptimizationProblem(optf, ϕ_ini0);
183+ # res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
184+ end
185+
186+ # using Lux
187+ ϕ = ϕcg_true |> gpu;
188+ x_o_gpu = x_o |> gpu;
189+ # y_o = y_o |> gpu
190+ # logσ2y = logσ2y |> gpu
191+ n_covar = size (x_o, 1 )
192+ g_flux = Flux. Chain (
193+ # dense layer with bias that maps to 8 outputs and applies `tanh` activation
194+ Flux. Dense (n_covar => n_covar * 4 , tanh),
195+ Flux. Dense (n_covar * 4 => n_covar * 4 , logistic),
196+ # dense layer without bias that maps to n outputs and `identity` activation
197+ Flux. Dense (n_covar * 4 => n_θM, identity, bias= false ),
198+ )
199+ () -> begin
200+ using Lux
201+ g_lux = Lux. Chain (
202+ # dense layer with bias that maps to 8 outputs and applies `tanh` activation
203+ Lux. Dense (n_covar => n_covar * 4 , tanh),
204+ Lux. Dense (n_covar * 4 => n_covar * 4 , logistic),
205+ # dense layer without bias that maps to n outputs and `identity` activation
206+ Lux. Dense (n_covar * 4 => n_θM, identity, use_bias= false ),
207+ )
208+ ps, st = Lux. setup (Random. default_rng (), g_lux)
209+ ps_ca = CA. ComponentArray (ps) |> gpu
210+ st = st |> gpu
211+ g_luxs = StatefulLuxLayer {true} (g_lux, nothing , st)
212+ g_luxs (x_o_gpu[:, 1 : n_batch], ps_ca)
213+ ax_g = CA. getaxes (ps_ca)
214+ g_luxs (x_o_gpu[:, 1 : n_batch], CA. ComponentArray (ϕ. ϕg, ax_g))
215+ interpreters = (interpreters... , ϕg = ComponentArrayInterpreter (ps_ca))
216+ ϕg = CA. ComponentArray (ϕ. ϕg, ax_g)
217+ ϕgc = interpreters. ϕg (ϕ. ϕg)
218+ g_gpu = g_luxs
219+ end
220+ g_gpu = g_flux
221+
222+ # Zygote.gradient(ϕg -> sum(g_gpu(x_o_gpu[:, 1:n_batch],ϕg)), ϕgc)
223+ # Zygote.gradient(ϕg -> sum(compute_g(g_gpu, x_o_gpu[:, 1:n_batch], ϕg, interpreters)), ϕ.ϕg)
224+ # Zygote.gradient(ϕ -> sum(tmp_gen1(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ.ϕg)
225+ # Zygote.gradient(ϕ -> sum(tmp_gen2(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), CA.getdata(ϕ))
226+ # Zygote.gradient(ϕ -> sum(tmp_gen2(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ) |> cpu
227+ # Zygote.gradient(ϕ -> sum(tmp_gen3(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ) |> cpu
228+ # Zygote.gradient(ϕ -> sum(tmp_gen4(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)[1]), ϕ) |> cpu
229+ # generate_ζ(rng, g_gpu, f, ϕ, x_o_gpu[:, 1:n_batch], interpreters)
230+ # Zygote.gradient(ϕ -> sum(generate_ζ(rng, g_gpu, f, ϕ, x_o_gpu[:, 1:n_batch], interpreters)[1]), ϕ) |> cpu
231+ # include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss
232+ # neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
233+ # x_o_gpu[:, 1:n_batch], transPMs_batch, interpreters; logσ2y)
234+ # Zygote.gradient(ϕ -> sum(neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
235+ # x_o_gpu[:, 1:n_batch], transPMs_batch, interpreters; logσ2y)[1]), ϕ) |> cpu
236+
237+
238+ fcost (ϕ) = neg_elbo_transnorm_gf (rng, g_gpu, f, ϕ, y_o[:, 1 : n_batch],
239+ x_o_gpu[:, 1 : n_batch], transPMs_batch, map (get_concrete, interpreters);
240+ n_MC= 8 , logσ2y = logσ2y)
241+ fcost (ϕ)
242+ gr = Zygote. gradient (fcost, ϕ) |> cpu;
243+ Zygote. gradient (fcost, CA. getdata (ϕ))
244+
245+
246+ train_loader = MLUtils. DataLoader ((x_o_gpu, y_o), batchsize = n_batch)
247+
248+ optf = Optimization. OptimizationFunction ((ζg, data) -> begin
249+ x_o, y_o = data
250+ neg_elbo_transnorm_gf (
251+ rng, g_gpu, f, ζg, y_o, x_o, transPMs_batch, map (get_concrete, interpreters_g); n_MC= 5 , logσ2y)
252+ end ,
82253 Optimization. AutoZygote ())
83- optprob = OptimizationProblem (optf, p0, train_loader)
254+ optprob = Optimization. OptimizationProblem (optf, CA. getdata (ϕ_ini) |> gpu, train_loader);
255+ res = res_gpu = Optimization. solve (optprob, Optimisers. Adam (0.02 ), callback= callback_loss (50 ), maxiters= 800 );
256+
257+ ζ_VIc = interpreters_g. μP_ϕg_unc (res. u |> cpu)
258+ ζMs_VI = g (x_o, ζ_VIc. ϕg)
259+ ϕunc_VI = int_unc (ζ_VIc. unc)
260+
261+ hcat (θP_true, exp .(ζ_VIc. μP))
262+ plt = scatterplot (vec (θMs_true), vec (exp .(ζMs_VI)))
263+ # lineplot!(plt, 0.0, 1.1, identity)
264+ #
265+ hcat (ϕunc, ϕunc_VI) # need to compare to MC sample
266+ # hard to estimate for original very small theta's but otherwise good
267+
268+ # test predicting correct obs-uncertainty of predictive posterior
269+ n_sample_pred = 200
270+ intm_PMs_gen = ComponentArrayInterpreter (CA. ComponentVector (; θP,
271+ θMs= CA. ComponentMatrix (
272+ zeros (n_θM, n_site), first (CA. getaxes (θM)), CA. Axis (i= 1 : n_sample_pred))))
273+
274+ include (joinpath (@__DIR__ , " uncNN" , " elbo.jl" )) # callback_loss
275+ ζs, _ = generate_ζ (rng, g, f, res. u |> cpu, x_o,
276+ (;interpreters... , PMs = intm_PMs_gen); n_MC= n_sample_pred)
277+ # ζ = ζs[:,1]
278+ θsc = stack (ζ -> CA. getdata (CA. ComponentVector (
279+ TransformVariables. transform (transPMs_all, ζ))), eachcol (ζs));
280+ y_pred = stack (map (ζ -> first (predict_y (ζ, f, transPMs_all)), eachcol (ζs)));
281+
282+ size (y_pred)
283+ σ_o_post = mapslices (std, y_pred; dims= 3 );
284+ # describe(σ_o_post)
285+ vcat (σ_o, mean_σ_o_MC, mean (σ_o_post), sqrt (mean (abs2, σ_o_post)))
286+ mean_y_pred = map (mean, eachslice (y_pred; dims= (1 , 2 )))
287+ # describe(mean_y_pred - y_o)
288+ histogram (vec (mean_y_pred - y_true)) # predictions centered around y_o (or y_true)
289+
290+ # look at θP, θM1 of first site
291+ intm = ComponentArrayInterpreter (int_θdoubleMM (1 : length (int_θdoubleMM)), (n_sample_pred,))
292+ ζs1c = intm (ζs[1 : length (int_θdoubleMM), :])
293+ vcat (θP_true, θM_true)
294+ histogram (exp .(ζs1c[:r0 , :]))
295+ histogram (exp .(ζs1c[:K2 , :]))
296+ histogram (exp .(ζs1c[:r1 , :]))
297+ histogram (exp .(ζs1c[:K1 , :]))
298+ # all parameters estimated to high (true not in cf bounds)
299+ scatterplot (ζs1c[:r1 , :], ζs1c[:K1 , :]) # r1 and K1 strongly correlated (from θM)
300+ scatterplot (ζs1c[:r0 , :], ζs1c[:K2 , :]) # r0 and K also correlated (from θP)
301+ scatterplot (ζs1c[:r0 , :], ζs1c[:K1 , :]) # no correlation (modeled independent)
302+
303+ # TODO compare distributions to MC sample
304+
305+
306+
307+
84308
85- res = Optimization. solve (
86- optprob, Adam (0.02 ), callback = callback_loss (100 ), maxiters = 1000 );
87309
88- l1, y_pred_global, y_pred, θMs = loss_gf (res. u, train_loader. data... )
89- scatterplot (vec (θMs_true), vec (θMs))
90- scatterplot (log .(vec (θMs_true)), log .(vec (θMs)))
91- scatterplot (vec (y_pred), vec (y_o))
92- hcat (par_templates. θP, int_ϕθP (res. u). θP)
0 commit comments