@@ -53,8 +53,8 @@ scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
5353
5454f = gen_hybridcase_PBmodel (case; scenario)
5555
56+ # ----------- fit g and θP to y_o
5657() -> begin
57- # ----------- fit g and θP to y_o
5858 # end2end inversion
5959
6060 int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
@@ -82,67 +82,78 @@ f = gen_hybridcase_PBmodel(case; scenario)
8282end
8383
8484# ---------- HVI
85- # TODO think about good general initializations
86- coef_logσ2_logMs = [- 5.769 - 3.501 ; - 0.01791 0.007951 ]
87- logσ2_logP = CA. ComponentVector (r0 = - 8.997 , K2 = - 5.893 )
88- mean_σ_o_MC = 0.006042
89-
90- # correlation matrices
91- ρsP = zeros (sum (1 : (n_θP - 1 )))
92- ρsM = zeros (sum (1 : (n_θM - 1 )))
93-
94- ϕunc = CA. ComponentVector (;
95- logσ2_logP = logσ2_logP,
96- coef_logσ2_logMs = coef_logσ2_logMs,
97- ρsP,
98- ρsM)
99- int_unc = ComponentArrayInterpreter (ϕunc)
100-
101- # for a conservative uncertainty assume σ2=1e-10 and no relationship with magnitude
102- ϕunc0 = CA. ComponentVector (;
103- logσ2_logP = fill (- 10.0 , n_θP),
104- coef_logσ2_logMs = reduce (hcat, ([- 10.0 , 0.0 ] for _ in 1 : n_θM)),
105- ρsP,
106- ρsM)
107-
10885logσ2y = 2 .* log .(σ_o)
10986n_MC = 3
87+ (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params (
88+ θP_true, θMs_true[:, 1 ], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
89+ ϕ_true = ϕ
11090
111- transPMs_batch = as (
112- (P = as (Array, asℝ₊, n_θP),
113- Ms = as (Array, asℝ₊, n_θM, n_batch)))
114- transPMs_all = as (
115- (P = as (Array, asℝ₊, n_θP),
116- Ms = as (Array, asℝ₊, n_θM, n_site)))
117-
118- n_ϕg = length (ϕg_opt1)
119- ϕt_true = θ = CA. ComponentVector (;
120- μP = θP_true,
121- ϕg = ϕg_opt1,
122- unc = ϕunc);
123- trans_gu = as (
124- (μP = as (Array, asℝ₊, n_θP),
125- ϕg = as (Array, n_ϕg),
126- unc = as (Array, length (ϕunc))))
127- trans_g = as (
128- (μP = as (Array, asℝ₊, n_θP),
129- ϕg = as (Array, n_ϕg)))
130-
131- # const
132- int_PMs_batch = ComponentArrayInterpreter (CA. ComponentVector (; θP = θP_true,
133- θMs = CA. ComponentMatrix (
134- zeros (n_θM, n_batch), first (CA. getaxes (θMs_true)), CA. Axis (i = 1 : n_batch))))
135-
136- interpreters = interpreters_g = map (get_concrete,
137- (;
138- μP_ϕg_unc = ComponentArrayInterpreter (ϕt_true),
139- PMs = int_PMs_batch,
140- unc = ComponentArrayInterpreter (ϕunc)
141- ))
142-
143- ϕ_true = inverse_ca (trans_gu, ϕt_true)
91+ () -> begin
92+ coef_logσ2_logMs = [- 5.769 - 3.501 ; - 0.01791 0.007951 ]
93+ logσ2_logP = CA. ComponentVector (r0 = - 8.997 , K2 = - 5.893 )
94+ mean_σ_o_MC = 0.006042
95+
96+ # correlation matrices
97+ ρsP = zeros (sum (1 : (n_θP - 1 )))
98+ ρsM = zeros (sum (1 : (n_θM - 1 )))
99+
100+ ϕunc = CA. ComponentVector (;
101+ logσ2_logP = logσ2_logP,
102+ coef_logσ2_logMs = coef_logσ2_logMs,
103+ ρsP,
104+ ρsM)
105+ int_unc = ComponentArrayInterpreter (ϕunc)
106+
107+ # for a conservative uncertainty assume logσ2 = -10 (σ2 = exp(-10) ≈ 4.5e-5, if natural log) and no relationship with magnitude
108+ ϕunc0 = CA. ComponentVector (;
109+ logσ2_logP = fill (- 10.0 , n_θP),
110+ coef_logσ2_logMs = reduce (hcat, ([- 10.0 , 0.0 ] for _ in 1 : n_θM)),
111+ ρsP,
112+ ρsM)
113+
114+ transPMs_batch = as (
115+ (P = as (Array, asℝ₊, n_θP),
116+ Ms = as (Array, asℝ₊, n_θM, n_batch)))
117+ transPMs_allsites = as (
118+ (P = as (Array, asℝ₊, n_θP),
119+ Ms = as (Array, asℝ₊, n_θM, n_site)))
120+
121+ n_ϕg = length (ϕg_opt1)
122+ ϕt_true = θ = CA. ComponentVector (;
123+ μP = θP_true,
124+ ϕg = ϕg_opt1,
125+ unc = ϕunc)
126+ trans_gu = as (
127+ (μP = as (Array, asℝ₊, n_θP),
128+ ϕg = as (Array, n_ϕg),
129+ unc = as (Array, length (ϕunc))))
130+ trans_g = as (
131+ (μP = as (Array, asℝ₊, n_θP),
132+ ϕg = as (Array, n_ϕg)))
133+
134+ # const
135+ int_PMs_batch = ComponentArrayInterpreter (CA. ComponentVector (; θP = θP_true,
136+ θMs = CA. ComponentMatrix (
137+ zeros (n_θM, n_batch), first (CA. getaxes (θMs_true)), CA. Axis (i = 1 : n_batch))))
138+
139+ interpreters = interpreters_g = map (get_concrete,
140+ (;
141+ μP_ϕg_unc = ComponentArrayInterpreter (ϕt_true),
142+ PMs = int_PMs_batch,
143+ unc = ComponentArrayInterpreter (ϕunc)
144+ ))
145+
146+ ϕ_true = inverse_ca (trans_gu, ϕt_true)
147+ end
148+
149+ ϕ_ini0 = ζ = vcat (ϕ_true[:μP ] .* 0.0 , ϕg0, ϕ_true[[:unc ]]); # scratch
150+ #
151+ # true values
144152ϕ_ini = ζ = vcat (ϕ_true[[:μP , :ϕg ]] .* 1.2 , ϕ_true[[:unc ]]); # slight disturbance
145- ϕ_ini0 = ζ = vcat (ϕ_true[:μP ] .* 0.0 , ϕg0, ϕunc0); # scratch
153+ # hardcoded from HMC inversion
154+ ϕ_ini. unc. coef_logσ2_logMs = [- 5.769 - 3.501 ; - 0.01791 0.007951 ]
155+ ϕ_ini. unc. logσ2_logP = CA. ComponentVector (r0 = - 8.997 , K2 = - 5.893 )
156+ mean_σ_o_MC = 0.006042
146157
147158# test cost function and gradient
148159() -> begin
@@ -161,10 +172,10 @@ end
161172 train_loader = MLUtils. DataLoader ((xM, y_o), batchsize = n_batch)
162173
163174 optf = Optimization. OptimizationFunction (
164- (ζg , data) -> begin
175+ (ϕ , data) -> begin
165176 xM, y_o = data
166177 neg_elbo_transnorm_gf (
167- rng, g, f, ζg , y_o, xM, transPMs_batch,
178+ rng, g, f, ϕ , y_o, xM, transPMs_batch,
168179 map (get_concrete, interpreters_g); n_MC = 5 , logσ2y)
169180 end ,
170181 Optimization. AutoZygote ())
@@ -181,7 +192,7 @@ g_flux, ϕg0_flux_cpu = gen_hybridcase_MLapplicator(case, FluxMLengine; scenario
181192
182193# optimize using Lux
183194() -> begin
184- using Lux
195+ # using Lux
185196 g_lux = Lux. Chain (
186197 # dense layer with bias that maps to `n_covar * 4` outputs and applies `tanh` activation
187198 Lux. Dense (n_covar => n_covar * 4 , tanh),
@@ -208,18 +219,19 @@ function fcost(ϕ)
208219 n_MC = 8 , logσ2y = logσ2y)
209220end
210221fcost (ϕ)
211- Zygote. gradient (fcost, ϕ) |> cpu;
222+ # Zygote.gradient(fcost, ϕ) |> cpu;
212223gr = Zygote. gradient (fcost, CA. getdata (ϕ));
213- gr_c = CA. ComponentArray (gr[1 ], CA. getaxes (ϕ)... )
224+ gr_c = CA. ComponentArray (gr[1 ] |> Flux . cpu , CA. getaxes (ϕ)... )
214225
215226train_loader = MLUtils. DataLoader ((xM_gpu, y_o), batchsize = n_batch)
216227
217228optf = Optimization. OptimizationFunction (
218- (ζg , data) -> begin
229+ (ϕ , data) -> begin
219230 xM, y_o = data
220- neg_elbo_transnorm_gf (
221- rng, g_flux, f, ζg, y_o, xM, transPMs_batch,
222- map (get_concrete, interpreters_g); n_MC = 5 , logσ2y)
231+ fcost (ϕ)
232+ # neg_elbo_transnorm_gf(
233+ # rng, g_flux, f, ϕ, y_o, xM, transPMs_batch,
234+ # map(get_concrete, interpreters); n_MC = 5, logσ2y)
223235 end ,
224236 Optimization. AutoZygote ())
225237optprob = Optimization. OptimizationProblem (
@@ -230,40 +242,31 @@ res = res_gpu = Optimization.solve(
230242# start from zero
231243() -> begin
232244 optprob = Optimization. OptimizationProblem (
233- optf, CA. getdata (ϕ_ini0) |> Flux. gpu, train_loader);
245+ optf, CA. getdata (ϕ_ini0) |> Flux. gpu, train_loader)
234246 res = res_gpu = Optimization. solve (
235- optprob, Optimisers. Adam (0.02 ), callback = callback_loss (50 ), maxiters = 4_000 );
247+ optprob, Optimisers. Adam (0.02 ), callback = callback_loss (50 ), maxiters = 4_000 )
236248end
237249
238- ζ_VIc = interpreters_g . μP_ϕg_unc (res. u |> Flux. cpu)
239- ζMs_VI = g (xM , ζ_VIc. ϕg)
240- ϕunc_VI = int_unc (ζ_VIc. unc)
250+ ζ_VIc = interpreters . μP_ϕg_unc (res. u |> Flux. cpu)
251+ ζMs_VI = g_flux (xM_gpu , ζ_VIc. ϕg |> Flux . gpu) |> Flux . cpu
252+ ϕunc_VI = interpreters . unc (ζ_VIc. unc)
241253
242254hcat (θP_true, exp .(ζ_VIc. μP))
243255plt = scatterplot (vec (θMs_true), vec (exp .(ζMs_VI)))
244256# lineplot!(plt, 0.0, 1.1, identity)
245257#
246- hcat (ϕunc , ϕunc_VI) # need to compare to MC sample
258+ hcat (ϕ_ini . unc , ϕunc_VI) # need to compare to MC sample
247259# hard to estimate for original very small theta's but otherwise good
248260
249261# test predicting correct obs-uncertainty of predictive posterior
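250262# NOTE: the prediction below already uses g_flux (see the predict_gf call); the earlier TODO to reuse g_flux rather than g looks resolved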
251263n_sample_pred = 200
252- intm_PMs_gen = ComponentArrayInterpreter (CA. ComponentVector (; θP = θP_true,
253- θMs = CA. ComponentMatrix (
254- zeros (n_θM, n_site), first (CA. getaxes (θMs_true)), CA. Axis (i = 1 : n_sample_pred))))
255-
256- ζs, _ = HVI. generate_ζ (rng, g, f, res. u |> Flux. cpu, xM,
257- (; interpreters... , PMs = intm_PMs_gen); n_MC = n_sample_pred)
258- # ζ = ζs[:,1]
259- θsc = stack (
260- ζ -> CA. getdata (CA. ComponentVector (
261- TransformVariables. transform (transPMs_all, ζ))),
262- eachcol (ζs));
263- y_pred = stack (map (ζ -> first (HVI. predict_y (ζ, f, transPMs_all)), eachcol (ζs)));
264-
265- size (y_pred)
266- σ_o_post = mapslices (std, y_pred; dims = 3 )[:, :, 1 ];
264+ y_pred = predict_gf (rng, g_flux, f, res. u, xM_gpu, interpreters;
265+ get_transPMs, get_ca_int_PMs, n_sample_pred);
266+ size (y_pred) # n_obs x n_site x n_sample_pred
267+
268+ σ_o_post = dropdims (std (y_pred; dims = 3 ), dims= 3 )
269+
267270# describe(σ_o_post)
268271hcat (σ_o, fill (mean_σ_o_MC, length (σ_o)),
269272 mean (σ_o_post, dims = 2 ), sqrt .(mean (abs2, σ_o_post, dims = 2 )))
0 commit comments