@@ -14,42 +14,38 @@ using OptimizationOptimisers
 
 const MLengine = Val(nameof(SimpleChains))
 
-
 construct_problem = () -> begin
-    S1 = [1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
-    S2 = [1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
     θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
-    θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
+    θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
     transP = elementwise(exp)
     transM = Stacked(elementwise(identity), elementwise(exp))
     n_covar = 5
     n_batch = 10
     int_θdoubleMM = get_concrete(ComponentArrayInterpreter(
         flatten1(CA.ComponentVector(; θP, θM))))
-    function f_doubleMM(θ::AbstractVector)
+    function f_doubleMM(θ::AbstractVector, x)
         # extract parameters not depending on order, i.e. whether they are in θP or θM
         θc = int_θdoubleMM(θ)
         r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)]
-        y = r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+        y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
         return (y)
     end
-    fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
     function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
-        pred_sites = applyf(fsite, θMs, θP, x)
+        pred_sites = applyf(f_doubleMM, θMs, θP, x)
         pred_global = eltype(pred_sites)[]
         return pred_global, pred_sites
-    end
+    end
     n_out = length(θM)
     g_chain = SimpleChain(
-        static(n_covar), # input dimension (optional)
-        # dense layer with bias that maps to n_covar * 4 outputs and applies `tanh` activation
-        TurboDense{true}(tanh, n_covar * 4),
-        TurboDense{true}(tanh, n_covar * 4),
-        # dense layer without bias that maps to n_out outputs and `identity` activation
-        TurboDense{false}(identity, n_out),
-    )
+        static(n_covar), # input dimension (optional)
+        # dense layer with bias that maps to n_covar * 4 outputs and applies `tanh` activation
+        TurboDense{true}(tanh, n_covar * 4),
+        TurboDense{true}(tanh, n_covar * 4),
+        # dense layer without bias that maps to n_out outputs and `identity` activation
+        TurboDense{false}(identity, n_out)
+    )
     g = construct_SimpleChainsApplicator(g_chain)
-    ϕg = SimpleChains.init_params(g_chain, eltype(θM));
+    ϕg = SimpleChains.init_params(g_chain, eltype(θM))
     HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global, g, ϕg)
 end
 prob = construct_problem();
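
The substantive change in this hunk: `f_doubleMM` now takes the site drivers as a second argument `x` instead of closing over the module-level `S1`/`S2` vectors, which also makes the `fsite` wrapper obsolete. A minimal sketch of the new call pattern, with driver values assumed (reused from the deleted vectors) and `x_site` a hypothetical name:

    # hypothetical per-site driver tuple; field names S1/S2 match the new x.S1/x.S2 accesses
    x_site = (S1 = [1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1],
              S2 = [1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0])
    r0, r1, K1, K2 = 0.3, 0.5, 0.2, 2.0
    # double Michaelis-Menten: both drivers saturate, each with its own half-saturation constant
    y = r0 .+ r1 .* x_site.S1 ./ (K1 .+ x_site.S1) .* x_site.S2 ./ (K2 .+ x_site.S2)
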
@@ -65,7 +61,7 @@ rng = StableRNG(111)
 ) = gen_hybridcase_synthetic(case_syn, rng; scenario);
 
 @testset "loss_g" begin
-    g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario);
+    g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario)
 
     function loss_g(ϕg, x, g)
         ζMs = g(x, ϕg) # predict the log of the parameters
@@ -74,15 +70,15 @@ rng = StableRNG(111)
         return loss, θMs
     end
     loss_g(ϕg0, xM, g)
-    Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
+    Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0)
 
     optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
         Optimization.AutoZygote())
-    optprob = Optimization.OptimizationProblem(optf, ϕg0);
+    optprob = Optimization.OptimizationProblem(optf, ϕg0)
     # res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600);
-    res = Optimization.solve(optprob, Adam(0.02), maxiters = 600);
+    res = Optimization.solve(optprob, Adam(0.02), maxiters = 600)
 
-    ϕg_opt1 = res.u;
+    ϕg_opt1 = res.u
     pred = loss_g(ϕg_opt1, xM, g)
     θMs_pred = pred[2]
     # scatterplot(vec(θMs_true), vec(θMs_pred))
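
For orientation, a hedged sketch of the complete `loss_g` as it presumably reads after this commit; the middle lines (back-transform and loss) fall between the shown hunks, so the `exp.` transform and the sum of squares against `θMs_true` are assumptions, not the file's verbatim code:

    function loss_g_sketch(ϕg, x, g, θMs_true)
        ζMs = g(x, ϕg)          # predict the log of the parameters
        θMs = exp.(ζMs)         # back-transform to parameter scale (assumed)
        loss = sum(abs2, θMs .- θMs_true)   # squared-error fit (assumed)
        return loss, θMs
    end
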
@@ -91,12 +87,12 @@
 
 @testset "loss_gf" begin
     # ----------- fit g and θP to y_o
-    g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario);
+    g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario)
     f = get_hybridcase_PBmodel(prob; scenario)
 
     int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
         ϕg = 1:length(ϕg0), θP = par_templates.θP))
-    p = p0 = vcat(ϕg0, par_templates.θP .* 0.8); # slightly disturb θP_true
+    p = p0 = vcat(ϕg0, par_templates.θP .* 0.8) # slightly disturb θP_true
 
     # Pass the site-data for the batches as separate vectors wrapped in a tuple
     train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
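
A self-contained illustration of this batching pattern with toy, assumed sizes: `MLUtils.DataLoader` slices every array of the tuple along its last (observation) dimension, so each iteration yields one `(xM, xP, y_o)` batch of `n_batch` sites:

    using MLUtils
    xM  = rand(Float32, 5, 100)   # n_covar × n_site covariates (toy)
    xP  = rand(Float32, 3, 100)   # per-site drivers (toy)
    y_o = rand(Float32, 7, 100)   # per-site observations (toy)
    loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = 10)
    xM_b, xP_b, y_b = first(loader)
    @assert size(xM_b) == (5, 10)  # rows kept, columns sliced into batches
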
@@ -109,8 +105,8 @@
     optprob = OptimizationProblem(optf, p0, train_loader)
 
     res = Optimization.solve(
-        # optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
-        optprob, Adam(0.02), maxiters = 1000);
+        # optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
+        optprob, Adam(0.02), maxiters = 1000)
 
     l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...)
     @test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11)
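
Passing `train_loader` as the third argument of `OptimizationProblem` uses Optimization.jl's data-iterator support, where each optimizer step receives the next batch as the loss's `p` argument. A hedged sketch of how `optf` is presumably wired up (its definition lies outside the shown hunks):

    optf = Optimization.OptimizationFunction(
        (p, data) -> loss_gf(p, data...)[1],  # first return value is the scalar loss (assumed)
        Optimization.AutoZygote())
    optprob = OptimizationProblem(optf, p0, train_loader)
    res = Optimization.solve(optprob, Adam(0.02), maxiters = 1000)
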