@@ -46,76 +46,50 @@ construct_problem = () -> begin
4646 )
4747 g = construct_SimpleChainsApplicator (g_chain)
4848 ϕg = SimpleChains. init_params (g_chain, eltype (θM))
49- HybridProblem (θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global, g, ϕg)
49+ #
50+ rng = StableRNG (111 )
51+ (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
52+ ) = gen_hybridcase_synthetic (DoubleMM. DoubleMMCase (), rng;);
53+ train_loader = MLUtils. DataLoader ((xM, xP, y_o), batchsize = n_batch)
54+ HybridProblem (θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global,
55+ g, ϕg, train_loader)
5056end
5157prob = construct_problem ();
52- case_syn = DoubleMM. DoubleMMCase ()
5358scenario = (:default ,)
5459
55- par_templates = get_hybridcase_par_templates (prob; scenario)
56-
57- (; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes (prob; scenario)
58-
59- rng = StableRNG (111 )
60- (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
61- ) = gen_hybridcase_synthetic (case_syn, rng; scenario);
62-
63- @testset " loss_g" begin
64- g, ϕg0 = get_hybridcase_MLapplicator (prob, MLengine; scenario)
65-
66- function loss_g (ϕg, x, g)
67- ζMs = g (x, ϕg) # predict the log of the parameters
68- θMs = exp .(ζMs)
69- loss = sum (abs2, θMs .- θMs_true)
70- return loss, θMs
71- end
72- loss_g (ϕg0, xM, g)
73- Zygote. gradient (x -> loss_g (x, xM, g)[1 ], ϕg0)
7460
75- optf = Optimization. OptimizationFunction ((ϕg, p) -> loss_g (ϕg, xM, g)[1 ],
76- Optimization. AutoZygote ())
77- optprob = Optimization. OptimizationProblem (optf, ϕg0)
78- # res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600);
79- res = Optimization. solve (optprob, Adam (0.02 ), maxiters = 600 )
80-
81- ϕg_opt1 = res. u
82- pred = loss_g (ϕg_opt1, xM, g)
83- θMs_pred = pred[2 ]
84- # scatterplot(vec(θMs_true), vec(θMs_pred))
85- @test cor (vec (θMs_true), vec (θMs_pred)) > 0.9
86- end
61+ # (; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(prob; scenario)
8762
8863@testset " loss_gf" begin
8964 # ----------- fit g and θP to y_o
9065 g, ϕg0 = get_hybridcase_MLapplicator (prob, MLengine; scenario)
66+ train_loader = get_hybridcase_train_dataloader (prob; scenario)
67+ (xM, xP, y_o) = first (train_loader)
9168 f = get_hybridcase_PBmodel (prob; scenario)
69+ par_templates = get_hybridcase_par_templates (prob; scenario)
9270
9371 int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
9472 ϕg = 1 : length (ϕg0), θP = par_templates. θP))
9573 p = p0 = vcat (ϕg0, par_templates. θP .* 0.8 ) # slightly disturb θP_true
9674
9775 # Pass the site-data for the batches as separate vectors wrapped in a tuple
98- train_loader = MLUtils. DataLoader ((xM, xP, y_o), batchsize = n_batch)
9976
77+ y_global_o = Float64[]
10078 loss_gf = get_loss_gf (g, f, y_global_o, int_ϕθP)
101- l1 = loss_gf (p0, train_loader. data... )[1 ]
79+ l1 = loss_gf (p0, first (train_loader)... )
80+ gr = Zygote. gradient (p -> loss_gf (p, train_loader. data... )[1 ], p0)
81+ @test gr[1 ] isa Vector
10282
103- optf = Optimization. OptimizationFunction ((ϕ, data) -> loss_gf (ϕ, data... )[1 ],
104- Optimization. AutoZygote ())
105- optprob = OptimizationProblem (optf, p0, train_loader)
106-
107- res = Optimization. solve (
108- # optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
109- optprob, Adam (0.02 ), maxiters = 1000 )
83+ () -> begin
84+ optf = Optimization. OptimizationFunction ((ϕ, data) -> loss_gf (ϕ, data... )[1 ],
85+ Optimization. AutoZygote ())
86+ optprob = OptimizationProblem (optf, p0, train_loader)
11087
111- l1, y_pred_global, y_pred, θMs_pred = loss_gf (res . u, train_loader . data ... )
112- @test isapprox (par_templates . θP, int_ϕθP (res . u) . θP, rtol = 0.11 )
113- @test cor ( vec (θMs_true), vec (θMs_pred)) > 0.9
88+ res = Optimization . solve (
89+ # optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
90+ optprob, Adam ( 0.02 ), maxiters = 1000 )
11491
115- () -> begin
116- scatterplot (vec (θMs_true), vec (θMs_pred))
117- scatterplot (log .(vec (θMs_true)), log .(vec (θMs_pred)))
118- scatterplot (vec (y_pred), vec (y_o))
119- hcat (par_templates. θP, int_ϕθP (p0). θP, int_ϕθP (res. u). θP)
92+ l1, y_pred_global, y_pred, θMs_pred = loss_gf (res. u, train_loader. data... )
93+ @test isapprox (par_templates. θP, int_ϕθP (res. u). θP, rtol = 0.11 )
12094 end
12195end
0 commit comments