@@ -141,21 +141,40 @@ end
141141end
142142
143143"""
144- predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpreters;
144+ predict_hvi([rng], prob::AbstractHybridProblem [,xM, xP]; scenario, ...)
145+ predict_hvi(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix;
145146 get_transPMs, get_ca_int_PMs, n_sample_pred=200, gdev = identity)
146147
147- Prediction function for hybrid model. Returns an NamedTuple with entries
148- - `θ`: ComponentArray `(n_θP + n_site * n_θM), n_sample_pred)` of PBM model parameters.
148+ Prediction function for hybrid variational inference parameter model.
149+
150+ ## Arguments
151+ - The problem for which to predict
152+ - xM: covariates for the machine-learning model (ML): Matrix (n_θM x n_site_pred).
153+ - xP: model drivers for process based model (PBM): Matrix with (n_site_pred) rows.
154+ If provided a ComponentArray with a Tuple-Axis in rows, the PBM model can
155+ access parts of it, e.g. `xP[:S1,...]`.
156+
157+ ## Keyword arguments
158+ - scenario
159+ - n_sample_pred
160+
161+ Returns an NamedTuple `(; y, θsP, θsMs, entropy_ζ)` with entries
149162- `y`: Array `(n_obs, n_site, n_sample_pred)` of model predictions.
163+ - `θsP`: ComponentArray `(n_θP, n_sample_pred)` of PBM model parameters
164+ that are kept constant across sites.
165+ - `θsMs`: ComponentArray `(n_site, n_θM, n_sample_pred)` of PBM model parameters
166+ that vary by site.
167+ - `entropy_ζ`: The entroy of the log-determinant of the transformation of
168+ the set of model parameters, which is involved in uncertainty quantification.
150169"""
151- function predict_gf (rng, prob:: AbstractHybridProblem ; scenario, kwargs... )
170+ function predict_hvi (rng, prob:: AbstractHybridProblem ; scenario, kwargs... )
152171 dl = get_hybridproblem_train_dataloader(prob; scenario)
153172 dl_dev = gdev_hybridproblem_dataloader(dl; scenario)
154173 # predict for all sites
155174 xM, xP = dl_dev. data[1 : 2 ]
156- predict_gf (rng, prob, xM, xP; scenario, kwargs... )
175+ predict_hvi (rng, prob, xM, xP; scenario, kwargs... )
157176end
158- function predict_gf (rng, prob:: AbstractHybridProblem , xM:: AbstractMatrix , xP;
177+ function predict_hvi (rng, prob:: AbstractHybridProblem , xM:: AbstractMatrix , xP;
159178 scenario,
160179 n_sample_pred= 200 ,
161180 gdev= :use_gpu ∈ _val_value(scenario) ? gpu_device() : identity,
@@ -184,12 +203,12 @@ function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP;
184203 int_unc = interpreters. unc
185204 transMs = StackedArray(transM, n_batch)
186205 g_dev, ϕ_dev = gdev(g), gdev(ϕ)
187- predict_gf (rng, g_dev, f, ϕ_dev, xM, xP;
206+ predict_hvi (rng, g_dev, f, ϕ_dev, xM, xP;
188207 int_μP_ϕg_unc, int_unc, transP, transM,
189208 n_sample_pred, cdev, cor_ends, pbm_covar_indices, kwargs... )
190209end
191210
192- function predict_gf (rng, g, f, ϕ:: AbstractVector , xM:: AbstractMatrix , xP;
211+ function predict_hvi (rng, g, f, ϕ:: AbstractVector , xM:: AbstractMatrix , xP;
193212 int_μP_ϕg_unc:: AbstractComponentArrayInterpreter ,
194213 int_unc:: AbstractComponentArrayInterpreter ,
195214 transP, transM,
@@ -274,6 +293,8 @@ function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT;
274293 # first pass: append μ_ζP_to covars, need ML prediction for magnitude of ζMs
275294 # TODO replace pbm_covar_indices by ComponentArray? dimensions to be type-inferred?
276295 xMP0 = _append_each_covars(xM, CA. getdata(μ_ζP), pbm_covar_indices)
296+ # Main.@infiltrate_main
297+
277298 μ_ζMs0 = g(xMP0, ϕg):: MT # for gpu restructure returns Any, so apply type
278299 ζP_resids, ζMs_parfirst_resids, σ = sample_ζresid_norm(rng, μ_ζP, μ_ζMs0, ϕc. unc; n_MC, cor_ends, int_unc)
279300 if pbm_covar_indices isa SA. SVector{0 }
@@ -362,7 +383,7 @@ ML-model predcitions of size `(n_θM, n_site)`.
362383
363384## Arguments
364385* `int_unc`: Interpret vector as ComponentVector with components
365- ρsP, ρsM, logσ2_logP , coef_logσ2_ζMs(intercept + slope),
386+ ρsP, ρsM, logσ2_ζP , coef_logσ2_ζMs(intercept + slope),
366387"""
367388function sample_ζresid_norm(rng:: Random.AbstractRNG , ζP:: AbstractVector , ζMs:: AbstractMatrix ,
368389 args... ; n_MC, cor_ends, int_unc)
@@ -392,10 +413,10 @@ function sample_ζresid_norm(urandn::AbstractMatrix, ζP::TP, ζMs::TM,
392413 UM = transformU_block_cholesky1(ρsM, cor_ends. M)
393414 cf = ϕuncc. coef_logσ2_ζMs
394415 logσ2_logMs = vec(cf[1 , :] .+ cf[2 , :] .* ζMs)
395- logσ2_logP = vec(CA. getdata(ϕuncc. logσ2_logP ))
416+ logσ2_ζP = vec(CA. getdata(ϕuncc. logσ2_ζP ))
396417 # CUDA cannot multiply BlockDiagonal * Diagonal, construct already those blocks
397418 σMs = reshape(exp.(logσ2_logMs ./ 2 ), n_θM, :)
398- σP = exp.(logσ2_logP ./ 2 )
419+ σP = exp.(logσ2_ζP ./ 2 )
399420 # BlockDiagonal does work with CUDA, but not with combination of Zygote and CUDA
400421 # need to construct full matrix for CUDA
401422 Uσ = _create_blockdiag(UP, UM, σP, σMs, n_batch)
0 commit comments