@@ -32,7 +32,7 @@ cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity
3232
3333# ------ setup synthetic data and training data loader
3434prob0_ = HybridProblem (DoubleMM. DoubleMMCase (); scenario);
35- (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o , y_o, y_unc
35+ (; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc
3636) = gen_hybridproblem_synthetic (rng, DoubleMM. DoubleMMCase (); scenario);
3737n_site, n_batch = get_hybridproblem_n_site_and_batch (prob0_; scenario)
3838ζP_true, ζMs_true = log .(θP_true), log .(θMs_true)
@@ -59,7 +59,7 @@ n_epoch = 80
5959 maxiters = n_batches_in_epoch * n_epoch);
6060# update the problem with optimized parameters
6161prob0o = prob1o = probo;
62- y_pred_global, y_pred, θMs = gf (prob0o; scenario, is_inferred= Val (true ));
62+ y_pred, θMs = gf (prob0o; scenario, is_inferred= Val (true ));
6363# @descend_code_warntype gf(prob0o; scenario)
6464# @usingany UnicodePlots
6565plt = scatterplot (θMs_true' [:, 1 ], θMs[:, 1 ]);
@@ -77,7 +77,7 @@ histogram(vec(y_pred) - vec(y_true)) # predictions centered around y_o (or y_tru
7777 (; ϕ, resopt) = solve (prob0o, solver1; scenario, rng,
7878 callback = callback_loss (20 ), maxiters = 400 )
7979 prob1o = HybridProblem (prob0o; ϕg = cpu_ca (ϕ). ϕg, θP = cpu_ca (ϕ). θP)
80- y_pred_global, y_pred, θMs = gf (prob1o, xM, xP; scenario)
80+ y_pred, θMs = gf (prob1o, xM, xP; scenario)
8181 scatterplot (θMs_true[1 , :], θMs[1 , :])
8282 scatterplot (θMs_true[2 , :], θMs[2 , :])
8383 prob1o. θP
9191 (; ϕ, resopt) = solve (prob2, solver1; scenario, rng,
9292 callback = callback_loss (20 ), maxiters = 600 )
9393 prob2o = HybridProblem (prob2; ϕg = collect (ϕ. ϕg), θP = ϕ. θP)
94- y_pred_global, y_pred, θMs = gf (prob2o, xM, xP)
94+ y_pred, θMs = gf (prob2o, xM, xP)
9595 prob2o. θP
9696end
9797
127127 (; ϕ, resopt) = solve (prob3, solver1; scenario, rng,
128128 callback = callback_loss (50 ), maxiters = 600 )
129129 prob3o = HybridProblem (prob3; ϕg = cpu_ca (ϕ). ϕg, θP = cpu_ca (ϕ). θP)
130- y_pred_global, y_pred, θMs = gf (prob3o, xM, xP; scenario)
130+ y_pred, θMs = gf (prob3o, xM, xP; scenario)
131131 scatterplot (θMs_true[2 , :], θMs[2 , :])
132132 prob3o. θP
133133 scatterplot (vec (y_true), vec (y_pred))
@@ -173,7 +173,7 @@ solver_post = HybridPosteriorSolver(; alg = OptimizationOptimisers.Adam(0.01), n
173173 (y1, θsP1, θsMs1) = (y, θsP, θsMs);
174174
175175 () -> begin # prediction with fitted parameters (should be smaller than mean)
176- y_pred_global, y_pred2, θMs = gf (prob1o, xM, xP; scenario)
176+ y_pred2, θMs = gf (prob1o, xM, xP; scenario)
177177 scatterplot (θMs_true[1 , :], θMs[1 , :])
178178 scatterplot (θMs_true[2 , :], θMs[2 , :])
179179 hcat (θP_true, θP) # all parameters overestimated
366366 # ζMs = invt.transM.(θMs_i)
367367 # _f = get_hybridproblem_PBmodel(probo; scenario)
368368 # y_site = map(eachcol(θPs), θMs_i) do θP, θM
369- # y_global, y = _f(θP, reshape(θM, (length(θM), 1)), xP[[i_site]])
369+ # y = _f(θP, reshape(θM, (length(θM), 1)), xP[[i_site]])
370370 # y[:,1]
371371 # end |> stack
372372 nLs = get_hybridproblem_neg_logden_obs (
0 commit comments