@@ -11,49 +11,51 @@ using UnicodePlots
1111using SimpleChains
1212using Flux
1313using MLUtils
14+ using CUDA
1415
1516rng = StableRNG (114 )
1617scenario = NTuple {0, Symbol} ()
17- # scenario = (:use_Flux,)
18+ scenario = (:use_Flux ,)
1819
1920# ------ setup synthetic data and training data loader
2021(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
2122) = gen_hybridcase_synthetic (rng, DoubleMM. DoubleMMCase (); scenario);
23+ xM_cpu = xM
24+ if :use_Flux ∈ scenario
25+ xM = CuArray (xM_cpu)
26+ end
2227get_train_loader = (rng; n_batch, kwargs... ) -> MLUtils. DataLoader ((xM, xP, y_o, y_unc), batchsize = n_batch)
2328σ_o = exp (first (y_unc)/ 2 )
2429
2530# assign the train_loader, otherwise it eatch time creates another version of synthetic data
26- prob0 = update (HybridProblem (DoubleMM. DoubleMMCase (); scenario); get_train_loader)
31+ prob0 = HVI . update (HybridProblem (DoubleMM. DoubleMMCase (); scenario); get_train_loader)
2732
2833# ------- pointwise hybrid model fit
2934# solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
3035solver = HybridPointSolver (; alg = Adam (0.01 ), n_batch = 10 )
3136# solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
3237(; ϕ, resopt) = solve (prob0, solver; scenario,
3338 rng, callback = callback_loss (100 ), maxiters = 1200 );
34- prob0o = update (prob0; ϕg= ϕ. ϕg, θP= ϕ. θP)
35- y_pred_global, y_pred, θMs = gf (prob0o, xM, xP);
39+ # update the problem with optimized parameters
40+ prob0o = HVI. update (prob0; ϕg= cpu_ca (ϕ). ϕg, θP= cpu_ca (ϕ). θP)
41+ y_pred_global, y_pred, θMs = gf (prob0o, xM, xP; scenario);
3642scatterplot (θMs_true[1 ,:], θMs[1 ,:])
3743scatterplot (θMs_true[2 ,:], θMs[2 ,:])
3844
3945# do a few steps without minibatching,
4046# by providing the data rather than the DataLoader
41- # train_loader0 = get_hybridproblem_train_dataloader(rng, prob0; scenario, n_batch=1000)
42- # get_train_loader_data = (args...; kwargs...) -> train_loader0.data
43- # prob1 = update(prob0o; get_train_loader = get_train_loader_data)
44- prob1 = prob0o
45-
46- # solver1 = HybridPointSolver(; alg = Adam(0.05), n_batch = n_site)
4747solver1 = HybridPointSolver (; alg = Adam (0.01 ), n_batch = n_site)
48- (; ϕ, resopt) = solve (prob1 , solver1; scenario, rng,
48+ (; ϕ, resopt) = solve (prob0o , solver1; scenario, rng,
4949 callback = callback_loss (20 ), maxiters = 600 );
50- prob1o = update (prob1 ; ϕg= ϕ . ϕg, θP= ϕ . θP)
51- y_pred_global, y_pred, θMs = gf (prob1o, xM, xP);
50+ prob1o = HVI . update (prob0o ; ϕg= cpu_ca (ϕ) . ϕg, θP= cpu_ca (ϕ) . θP);
51+ y_pred_global, y_pred, θMs = gf (prob1o, xM, xP; scenario );
5252scatterplot (θMs_true[1 ,:], θMs[1 ,:])
5353scatterplot (θMs_true[2 ,:], θMs[2 ,:])
5454prob1o. θP
5555scatterplot (vec (y_true), vec (y_pred))
5656
57+ # still overestimating θMs
58+
5759() -> begin # with more iterations?
5860 prob2 = prob1o
5961 (; ϕ, resopt) = solve (prob2, solver1; scenario, rng,
@@ -63,50 +65,55 @@ scatterplot(vec(y_true), vec(y_pred))
6365 prob2o. θP
6466end
6567
66- # ----------- fit g to true θMs
67- # and fit gf starting from true parameters
68- prob = prob0
69- g, ϕg0 = get_hybridproblem_MLapplicator (prob; scenario);
70- (; transP, transM) = get_hybridproblem_transforms (prob; scenario)
71-
72- function loss_g (ϕg, x, g, transM)
73- ζMs = g (x, ϕg) # predict the log of the parameters
74- θMs = reduce (hcat, map (transM, eachcol (ζMs))) # transform each column
75- loss = sum (abs2, θMs .- θMs_true)
76- return loss, θMs
77- end
78- loss_g (ϕg0, xM, g, transM)
7968
80- optf = Optimization. OptimizationFunction ((ϕg, p) -> loss_g (ϕg, xM, g, transM)[1 ],
81- Optimization. AutoZygote ())
82- optprob = Optimization. OptimizationProblem (optf, ϕg0);
83- res = Optimization. solve (optprob, Adam (0.015 ), callback = callback_loss (100 ), maxiters = 2000 );
84-
85- ϕg_opt1 = res. u;
86- l1, θMs = loss_g (ϕg_opt1, xM, g, transM)
87- # scatterplot(θMs_true[1,:], θMs[1,:])
88- scatterplot (θMs_true[2 ,:], θMs[2 ,:]) # able to fit θMs[2,:]
89-
90- prob3 = update (prob0, ϕg = ϕg_opt1, θP = θP_true)
91- solver1 = HybridPointSolver (; alg = Adam (0.01 ), n_batch = n_site)
92- (; ϕ, resopt) = solve (prob3, solver1; scenario, rng,
93- callback = callback_loss (50 ), maxiters = 600 );
94- prob3o = update (prob3; ϕg= ϕ. ϕg, θP= ϕ. θP)
95- y_pred_global, y_pred, θMs = gf (prob3o, xM, xP);
96- scatterplot (θMs_true[2 ,:], θMs[2 ,:])
97- prob3o. θP
98- scatterplot (vec (y_true), vec (y_pred))
99- scatterplot (vec (y_true), vec (y_o))
100- scatterplot (vec (y_pred), vec (y_o))
69+ # ----------- fit g to true θMs
70+ () -> begin
71+ # and fit gf starting from true parameters
72+ prob = prob0
73+ g, ϕg0_cpu = get_hybridproblem_MLapplicator (prob; scenario);
74+ ϕg0 = (:use_Flux ∈ scenario) ? CuArray (ϕg0_cpu) : ϕg0_cpu
75+ (; transP, transM) = get_hybridproblem_transforms (prob; scenario)
76+
77+ function loss_g (ϕg, x, g, transM; gpu_handler = HVI. default_GPU_DataHandler)
78+ ζMs = g (x, ϕg) # predict the log of the parameters
79+ ζMs_cpu = gpu_handler (ζMs)
80+ θMs = reduce (hcat, map (transM, eachcol (ζMs_cpu))) # transform each column
81+ loss = sum (abs2, θMs .- θMs_true)
82+ return loss, θMs
83+ end
84+ loss_g (ϕg0, xM, g, transM)
85+
86+ optf = Optimization. OptimizationFunction ((ϕg, p) -> loss_g (ϕg, xM, g, transM)[1 ],
87+ Optimization. AutoZygote ())
88+ optprob = Optimization. OptimizationProblem (optf, ϕg0);
89+ res = Optimization. solve (optprob, Adam (0.015 ), callback = callback_loss (100 ), maxiters = 2000 );
90+
91+ ϕg_opt1 = res. u;
92+ l1, θMs = loss_g (ϕg_opt1, xM, g, transM)
93+ # scatterplot(θMs_true[1,:], θMs[1,:])
94+ scatterplot (θMs_true[2 ,:], θMs[2 ,:]) # able to fit θMs[2,:]
95+
96+ prob3 = HVI. update (prob0, ϕg = Array (ϕg_opt1), θP = θP_true)
97+ solver1 = HybridPointSolver (; alg = Adam (0.01 ), n_batch = n_site)
98+ (; ϕ, resopt) = solve (prob3, solver1; scenario, rng,
99+ callback = callback_loss (50 ), maxiters = 600 );
100+ prob3o = HVI. update (prob3; ϕg= cpu_ca (ϕ). ϕg, θP= cpu_ca (ϕ). θP)
101+ y_pred_global, y_pred, θMs = gf (prob3o, xM, xP; scenario);
102+ scatterplot (θMs_true[2 ,:], θMs[2 ,:])
103+ prob3o. θP
104+ scatterplot (vec (y_true), vec (y_pred))
105+ scatterplot (vec (y_true), vec (y_o))
106+ scatterplot (vec (y_pred), vec (y_o))
101107
102- () -> begin # optimized loss is indeed lower than with true parameters
103- int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
104- ϕg = 1 : length (prob0. ϕg), θP = prob0. θP))
105- loss_gf = get_loss_gf (prob0. g, prob0. transM, prob0. f, Float32[], int_ϕθP)
106- loss_gf (vcat (prob3. ϕg, prob3. θP), xM, xP, y_o, y_unc)[1 ]
107- loss_gf (vcat (prob3o. ϕg, prob3o. θP), xM, xP, y_o, y_unc)[1 ]
108- #
109- loss_gf (vcat (prob2o. ϕg, prob2o. θP), xM, xP, y_o, y_unc)[1 ]
108+ () -> begin # optimized loss is indeed lower than with true parameters
109+ int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
110+ ϕg = 1 : length (prob0. ϕg), θP = prob0. θP))
111+ loss_gf = get_loss_gf (prob0. g, prob0. transM, prob0. f, Float32[], int_ϕθP)
112+ loss_gf (vcat (prob3. ϕg, prob3. θP), xM, xP, y_o, y_unc)[1 ]
113+ loss_gf (vcat (prob3o. ϕg, prob3o. θP), xM, xP, y_o, y_unc)[1 ]
114+ #
115+ loss_gf (vcat (prob2o. ϕg, prob2o. θP), xM, xP, y_o, y_unc)[1 ]
116+ end
110117end
111118
112119# ----------- Hybrid Variational inference
0 commit comments