@@ -17,35 +17,38 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve
1717 g, ϕg0 = get_hybridproblem_MLapplicator (prob; scenario)
1818 FT = get_hybridproblem_float_type (prob; scenario)
1919 (; transP, transM) = get_hybridproblem_transforms (prob; scenario)
20- int_ϕθP = ComponentArrayInterpreter (CA. ComponentVector (
21- ϕg = 1 : length (ϕg0), θP = par_templates. θP))
22- # p0_cpu = vcat(ϕg0, par_templates.θP .* FT(0.9)) # slightly disturb θP_true
23- p0_cpu = vcat (ϕg0, par_templates. θP)
24- p0 = p0_cpu
25- g_dev = g
20+ intϕ = ComponentArrayInterpreter (CA. ComponentVector (
21+ ϕg = 1 : length (ϕg0), ϕP = par_templates. θP))
22+ # ϕ0_cpu = vcat(ϕg0, par_templates.θP .* FT(0.9)) # slightly disturb θP_true
23+ ϕ0_cpu = vcat (ϕg0, apply_preserve_axes (inverse (transP),par_templates. θP))
2624 if gdev isa MLDataDevices. AbstractGPUDevice
27- p0 = gdev (p0_cpu )
25+ ϕ0_dev = gdev (ϕ0_cpu )
2826 g_dev = gdev (g)
27+ else
28+ ϕ0_dev = ϕ0_cpu
29+ g_dev = g
2930 end
3031 train_loader = get_hybridproblem_train_dataloader (
3132 prob; scenario, n_batch = solver. n_batch)
3233 f = get_hybridproblem_PBmodel (prob; scenario)
3334 y_global_o = FT[] # TODO
34- loss_gf = get_loss_gf (g_dev, transM, f, y_global_o, int_ϕθP; cdev)
35+ pbm_covars = get_hybridproblem_pbmpar_covars (prob; scenario)
36+ # intP = ComponentArrayInterpreter(par_templates.θP)
37+ loss_gf = get_loss_gf (g_dev, transM, transP, f, y_global_o, intϕ; cdev, pbm_covars)
3538 # call loss function once
36- l1 = loss_gf (p0 , first (train_loader)... )[1 ]
39+ l1 = loss_gf (ϕ0_dev , first (train_loader)... )[1 ]
3740 # and gradient
3841 # xMg, xP, y_o, y_unc = first(train_loader)
3942 # gr1 = Zygote.gradient(
4043 # p -> loss_gf(p, xMg, xP, y_o, y_unc)[1],
41- # p0 )
44+ # ϕ0_dev )
4245 # data1 = first(train_loader)
43- # Zygote.gradient(p0 -> loss_gf(p0 , data1...)[1], p0 )
46+ # Zygote.gradient(ϕ0_dev -> loss_gf(ϕ0_dev , data1...)[1], ϕ0_dev )
4447 optf = Optimization. OptimizationFunction ((ϕ, data) -> loss_gf (ϕ, data... )[1 ],
4548 Optimization. AutoZygote ())
46- optprob = OptimizationProblem (optf, CA. getdata (p0 ), train_loader)
49+ optprob = OptimizationProblem (optf, CA. getdata (ϕ0_dev ), train_loader)
4750 res = Optimization. solve (optprob, solver. alg; kwargs... )
48- (; ϕ = int_ϕθP (res. u), resopt = res)
51+ (; ϕ = intϕ (res. u), resopt = res)
4952end
5053
5154struct HybridPosteriorSolver{A} <: AbstractHybridSolver
@@ -77,6 +80,7 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
7780 g, ϕg0 = get_hybridproblem_MLapplicator (prob; scenario)
7881 ϕunc0 = get_hybridproblem_ϕunc (prob; scenario)
7982 (; transP, transM) = get_hybridproblem_transforms (prob; scenario)
83+ pbm_covars = get_hybridproblem_pbmpar_covars (prob; scenario)
8084 (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params (
8185 θP, θM, cor_ends, ϕg0, solver. n_batch; transP, transM, ϕunc0)
8286 if gdev isa MLDataDevices. AbstractGPUDevice
@@ -90,12 +94,12 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
9094 f = get_hybridproblem_PBmodel (prob; scenario)
9195 py = get_hybridproblem_neg_logden_obs (prob; scenario)
9296 priors_θ_mean = construct_priors_θ_mean (
93- prob, ϕ0_dev. ϕg, keys (θM), θP, θmean_quant, g_dev, transM;
94- scenario, get_ca_int_PMs, cdev)
97+ prob, ϕ0_dev. ϕg, keys (θM), θP, θmean_quant, g_dev, transM, transP ;
98+ scenario, get_ca_int_PMs, cdev, pbm_covars )
9599 y_global_o = Float32[] # TODO
96100 loss_elbo = get_loss_elbo (
97101 g_dev, transPMs_batch, f, py, y_global_o, interpreters;
98- solver. n_MC, solver. n_MC_cap, cor_ends, priors_θ_mean, cdev)
102+ solver. n_MC, solver. n_MC_cap, cor_ends, priors_θ_mean, cdev, pbm_covars, θP )
99103 # test loss function once
100104 l0 = loss_elbo (ϕ0_dev, rng, first (train_loader)... )
101105 optf = Optimization. OptimizationFunction ((ϕ, data) -> loss_elbo (ϕ, rng, data... )[1 ],
@@ -116,28 +120,32 @@ end
116120
117121"""
118122Create a loss function for parameter vector ϕ, given
119- - g(x, ϕ): machine learning model
120- - transPMS: transformation from unconstrained space to parameter space
121- - f(θMs, θP): mechanistic model
122- - interpreters: assigning structure to pure vectors, see neg_elbo_gtf
123- - n_MC: number of Monte-Carlo sample to approximate the expected value across distribution
123+ - `g(x, ϕ)`: machine learning model
124+ - `transPMS`: transformation from unconstrained space to parameter space
125+ - `f(θMs, θP)`: mechanistic model
126+ - `interpreters`: assigning structure to pure vectors, see `neg_elbo_gtf`
127+ - `n_MC`: number of Monte-Carlo samples to approximate the expected value across the distribution
128+ - `pbm_covars`: tuple of symbols of process-based parameters provided to the ML model
129+ - `θP`: ComponentVector used as a template to select the indices of pbm_covars
124130
125131The loss function takes in addition to ϕ, data that changes with minibatch
126- - rng: random generator
127- - xM : matrix of covariates, sites in columns
128- - xP : drivers for the processmodel: Iterator of size n_site
129- - y_o, y_unc: matrix of observations and uncertainties, sites in columns
132+ - ` rng` : random generator
133+ - `xM` : matrix of covariates, sites in columns
134+ - `xP` : drivers for the process model: Iterator of size n_site
135+ - ` y_o`, ` y_unc` : matrix of observations and uncertainties, sites in columns
130136"""
131137function get_loss_elbo (g, transPMs, f, py, y_o_global, interpreters;
132- n_MC, n_MC_cap = n_MC, cor_ends, priors_θ_mean, cdev)
138+ n_MC, n_MC_cap = n_MC, cor_ends, priors_θ_mean, cdev, pbm_covars, θP,
139+ )
133140 let g = g, transPMs = transPMs, f = f, py = py, y_o_global = y_o_global, n_MC = n_MC,
134141 cor_ends = cor_ends, interpreters = map (get_concrete, interpreters),
135- priors_θ_mean = priors_θ_mean, cdev = cdev
142+ priors_θ_mean = priors_θ_mean, cdev = cdev,
143+ pbm_covar_indices = get_pbm_covar_indices (θP, pbm_covars)
136144
137145 function loss_elbo (ϕ, rng, xM, xP, y_o, y_unc, i_sites)
138146 neg_elbo_gtf (
139147 rng, ϕ, g, transPMs, f, py, xM, xP, y_o, y_unc, i_sites, interpreters;
140- n_MC, n_MC_cap, cor_ends, priors_θ_mean, cdev)
148+ n_MC, n_MC_cap, cor_ends, priors_θ_mean, cdev, pbm_covar_indices )
141149 end
142150 end
143151end
@@ -183,16 +191,19 @@ end
183191In order to let mean of θ stay close to initial point parameter estimates
184192construct a prior on mean θ to a Normal around initial prediction.
185193"""
186- function construct_priors_θ_mean (prob, ϕg, keysθM, θP, θmean_quant, g_dev, transM;
187- scenario, get_ca_int_PMs, cdev)
194+ function construct_priors_θ_mean (prob, ϕg, keysθM, θP, θmean_quant, g_dev, transM, transP ;
195+ scenario, get_ca_int_PMs, cdev, pbm_covars )
188196 iszero (θmean_quant) ? [] :
189197 begin
190198 n_site = get_hybridproblem_n_site (prob; scenario)
191199 all_loader = get_hybridproblem_train_dataloader (prob; scenario, n_batch = n_site)
192200 xM_all = first (all_loader)[1 ]
193- θMs = gtrans (g_dev, transM, xM_all, CA. getdata (ϕg); cdev)
194- priors_dict = get_hybridproblem_priors (prob; scenario)
195201 # Main.@infiltrate_main
202+ ζP = apply_preserve_axes (inverse (transP), θP)
203+ pbm_covar_indices = get_pbm_covar_indices (θP, pbm_covars)
204+ xMP_all = _append_each_covars (xM_all, CA. getdata (ζP), pbm_covar_indices)
205+ θMs = gtrans (g_dev, transM, xMP_all, CA. getdata (ϕg); cdev)
206+ priors_dict = get_hybridproblem_priors (prob; scenario)
196207 priorsP = [priors_dict[k] for k in keys (θP)]
197208 priors_θP_mean = map (priorsP, θP) do priorsP, θPi
198209 fit_narrow_normal (θPi, priorsP, θmean_quant)
0 commit comments