@@ -7,7 +7,7 @@ using Statistics
77using ComponentArrays: ComponentArrays as CA
88
99using SimpleChains
10- import Flux # to allow for FluxMLEngine and cpu()
10+ import Flux
1111using MLUtils
1212import Zygote
1313
@@ -17,20 +17,22 @@ using Bijectors
1717using UnicodePlots
1818
1919const case = DoubleMM. DoubleMMCase ()
20- const MLengine = Val (nameof (SimpleChains))
21- const FluxMLengine = Val (nameof (Flux))
2220scenario = (:default ,)
2321rng = StableRNG (111 )
2422
2523par_templates = get_hybridcase_par_templates (case; scenario)
2624
27- (; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes (case; scenario)
25+ # n_covar = get_hybridcase_n_covar(case; scenario)
26+ # , n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
2827
2928(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
3029) = gen_hybridcase_synthetic (rng, case; scenario);
3130
31+ n_covar = size (xM,1 )
32+
33+
3234# ----- fit g to θMs_true
33- g, ϕg0 = get_hybridcase_MLapplicator (case, MLengine ; scenario);
35+ g, ϕg0 = get_hybridcase_MLapplicator (case; scenario);
3436(; transP, transM) = get_hybridcase_transforms (case; scenario)
3537
3638function loss_g (ϕg, x, g, transM)
@@ -90,6 +92,8 @@ FT = get_hybridcase_float_type(case; scenario)
9092 θP_true, θMs_true[:, 1 ], ϕg_opt1, n_batch; transP, transM);
9193ϕ_true = ϕ
9294
95+
96+
9397() -> begin
9498 coef_logσ2_logMs = [- 5.769 - 3.501 ; - 0.01791 0.007951 ]
9599 logσ2_logP = CA. ComponentVector (r0 = - 8.997 , K2 = - 5.893 )
@@ -162,7 +166,8 @@ mean_σ_o_MC = 0.006042
162166
163167ϕ = CA. getdata (ϕ_ini) |> Flux. gpu;
164168xM_gpu = xM |> Flux. gpu;
165- g_flux, _ = get_hybridcase_MLapplicator (case, FluxMLengine; scenario);
169+ scenario_flux = (scenario... , :use_Flux )
170+ g_flux, _ = get_hybridcase_MLapplicator (case; scenario = scenario_flux);
166171
167172# otpimize using LUX
168173() -> begin
@@ -200,7 +205,7 @@ gr = Zygote.gradient(fcost,
200205gr_c = CA. ComponentArray (gr[1 ] |> Flux. cpu, CA. getaxes (ϕ_ini)... )
201206
202207train_loader = MLUtils. DataLoader ((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
203- # train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux ))
208+ # train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_Flux ))
204209
205210optf = Optimization. OptimizationFunction (
206211 (ϕ, data) -> begin
0 commit comments