@@ -5,31 +5,123 @@ using StableRNGs
55using Random
66using Statistics
77using ComponentArrays: ComponentArrays as CA
8-
8+ using Optimization
9+ using OptimizationOptimisers # Adam
10+ using UnicodePlots
911using SimpleChains
10- import Flux
12+ using Flux
13+ using MLUtils
14+
15+ rng = StableRNG (114 )
16+ scenario = NTuple {0, Symbol} ()
17+ # scenario = (:use_Flux,)
18+
19+ # ------ setup synthetic data and training data loader
20+ (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
21+ ) = gen_hybridcase_synthetic (rng, DoubleMM. DoubleMMCase (); scenario);
22+ get_train_loader = (rng; n_batch, kwargs... ) -> MLUtils. DataLoader ((xM, xP, y_o, y_unc), batchsize = n_batch)
23+ σ_o = exp (first (y_unc)/ 2 )
24+
25+ # assign the train_loader, otherwise it eatch time creates another version of synthetic data
26+ prob0 = update (HybridProblem (DoubleMM. DoubleMMCase (); scenario); get_train_loader)
27+
28+ # ------- pointwise hybrid model fit
29+ # solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
30+ solver = HybridPointSolver (; alg = Adam (0.01 ), n_batch = 10 )
31+ # solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
32+ (; ϕ, resopt) = solve (prob0, solver; scenario,
33+ rng, callback = callback_loss (100 ), maxiters = 1200 );
34+ prob0o = update (prob0; ϕg= ϕ. ϕg, θP= ϕ. θP)
35+ y_pred_global, y_pred, θMs = gf (prob0o, xM, xP);
36+ scatterplot (θMs_true[1 ,:], θMs[1 ,:])
37+ scatterplot (θMs_true[2 ,:], θMs[2 ,:])
38+
39+ # do a few steps without minibatching,
40+ # by providing the data rather than the DataLoader
41+ # train_loader0 = get_hybridproblem_train_dataloader(rng, prob0; scenario, n_batch=1000)
42+ # get_train_loader_data = (args...; kwargs...) -> train_loader0.data
43+ # prob1 = update(prob0o; get_train_loader = get_train_loader_data)
44+ prob1 = prob0o
45+
46+ # solver1 = HybridPointSolver(; alg = Adam(0.05), n_batch = n_site)
47+ solver1 = HybridPointSolver (; alg = Adam (0.01 ), n_batch = n_site)
48+ (; ϕ, resopt) = solve (prob1, solver1; scenario, rng,
49+ callback = callback_loss (20 ), maxiters = 600 );
50+ prob1o = update (prob1; ϕg= ϕ. ϕg, θP= ϕ. θP)
51+ y_pred_global, y_pred, θMs = gf (prob1o, xM, xP);
52+ scatterplot (θMs_true[1 ,:], θMs[1 ,:])
53+ scatterplot (θMs_true[2 ,:], θMs[2 ,:])
54+ prob1o. θP
55+ scatterplot (vec (y_true), vec (y_pred))
56+
57+ () -> begin # with more iterations?
58+ prob2 = prob1o
59+ (; ϕ, resopt) = solve (prob2, solver1; scenario, rng,
60+ callback = callback_loss (20 ), maxiters = 600 );
61+ prob2o = update (prob2; ϕg= ϕ. ϕg, θP= ϕ. θP)
62+ y_pred_global, y_pred, θMs = gf (prob2o, xM, xP);
63+ prob2o. θP
64+ end
65+
66+ # ----------- fit g to true θMs
67+ # and fit gf starting from true parameters
68+ prob = prob0
69+ g, ϕg0 = get_hybridproblem_MLapplicator (prob; scenario);
70+ (; transP, transM) = get_hybridproblem_transforms (prob; scenario)
71+
72+ function loss_g (ϕg, x, g, transM)
73+ ζMs = g (x, ϕg) # predict the log of the parameters
74+ θMs = reduce (hcat, map (transM, eachcol (ζMs))) # transform each column
75+ loss = sum (abs2, θMs .- θMs_true)
76+ return loss, θMs
77+ end
78+ loss_g (ϕg0, xM, g, transM)
79+
80+ optf = Optimization. OptimizationFunction ((ϕg, p) -> loss_g (ϕg, xM, g, transM)[1 ],
81+ Optimization. AutoZygote ())
82+ optprob = Optimization. OptimizationProblem (optf, ϕg0);
83+ res = Optimization. solve (optprob, Adam (0.015 ), callback = callback_loss (100 ), maxiters = 2000 );
84+
85+ ϕg_opt1 = res. u;
86+ l1, θMs = loss_g (ϕg_opt1, xM, g, transM)
87+ # scatterplot(θMs_true[1,:], θMs[1,:])
88+ scatterplot (θMs_true[2 ,:], θMs[2 ,:]) # able to fit θMs[2,:]
89+
90+ prob3 = update (prob0, ϕg = ϕg_opt1, θP = θP_true)
91+ solver1 = HybridPointSolver (; alg = Adam (0.01 ), n_batch = n_site)
92+ (; ϕ, resopt) = solve (prob3, solver1; scenario, rng,
93+ callback = callback_loss (50 ), maxiters = 600 );
94+ prob3o = update (prob3; ϕg= ϕ. ϕg, θP= ϕ. θP)
95+ y_pred_global, y_pred, θMs = gf (prob3o, xM, xP);
96+ scatterplot (θMs_true[2 ,:], θMs[2 ,:])
97+ prob3o. θP
98+ scatterplot (vec (y_true), vec (y_pred))
99+ scatterplot (vec (y_true), vec (y_o))
100+ scatterplot (vec (y_pred), vec (y_o))
101+
102+ () -> begin # optimized loss is indeed lower than with true parameters
103+ int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
104+ ϕg = 1 : length (prob0. ϕg), θP = prob0. θP))
105+ loss_gf = get_loss_gf (prob0. g, prob0. transM, prob0. f, Float32[], int_ϕθP)
106+ loss_gf (vcat (prob3. ϕg, prob3. θP), xM, xP, y_o, y_unc)[1 ]
107+ loss_gf (vcat (prob3o. ϕg, prob3o. θP), xM, xP, y_o, y_unc)[1 ]
108+ #
109+ loss_gf (vcat (prob2o. ϕg, prob2o. θP), xM, xP, y_o, y_unc)[1 ]
110+ end
111+
112+ # ----------- Hybrid Variational inference
113+
11114using MLUtils
12115import Zygote
13116
14117using CUDA
15- using OptimizationOptimisers
16118using Bijectors
17- using UnicodePlots
18119
19- const prob = DoubleMM. DoubleMMCase ()
20- scenario = (:default ,)
21- rng = StableRNG (111 )
22-
23- par_templates = get_hybridproblem_par_templates (prob; scenario)
24120
25121# n_covar = get_hybridproblem_n_covar(prob; scenario)
26122# , n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
27123
28- (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
29- ) = gen_hybridcase_synthetic (rng, prob; scenario);
30-
31- n_covar = size (xM,1 )
32-
124+ n_covar = size (xM, 1 )
33125
34126# ----- fit g to θMs_true
35127g, ϕg0 = get_hybridproblem_MLapplicator (prob; scenario);
@@ -92,8 +184,6 @@ FT = get_hybridproblem_float_type(prob; scenario)
92184 θP_true, θMs_true[:, 1 ], ϕg_opt1, n_batch; transP, transM);
93185ϕ_true = ϕ
94186
95-
96-
97187() -> begin
98188 coef_logσ2_logMs = [- 5.769 - 3.501 ; - 0.01791 0.007951 ]
99189 logσ2_logP = CA. ComponentVector (r0 = - 8.997 , K2 = - 5.893 )
@@ -245,7 +335,7 @@ y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
245335size (y_pred) # n_obs x n_site, n_sample_pred
246336
247337σ_o_post = dropdims (std (y_pred; dims = 3 ), dims = 3 );
248- σ_o = exp .(y_unc[:,1 ] / 2 )
338+ σ_o = exp .(y_unc[:, 1 ] / 2 )
249339
250340# describe(σ_o_post)
251341hcat (σ_o, fill (mean_σ_o_MC, length (σ_o)),
0 commit comments