@@ -186,12 +186,12 @@ for hybrid variational inference problem.
186186- `rng`: random number generator
187187- `prob`: The AbstractHybridProblem from to sample
188188- `xM`: covariates for the machine-learning model (ML): Matrix `(n_θM x n_site_pred)`.
189- Default to all sites in train_dataloader in prob.
189+ Default to all sites in `get_hybridproblem_train_dataloader( prob; scenario)` .
190190
191191Optional keyword arguments
192192- `scenario`: scenario to query `prob` and set default of gpu devices.
193193- `n_sample_pred`: number of samples to draw, defaults to 200
194- - `gdevs`: NamedTuple(gdev_M, gdev_P): GPU devices for machine learning model
194+ - `gdevs`: ` NamedTuple(gdev_M, gdev_P)` : GPU devices for machine learning model
195195 and parameter transformtation, default to [`get_gdev_MP`](@ref)`(scenario)`.
196196- `is_inferred`: set to `Val(true)` to activate type stabilicy check for transformation
197197
@@ -238,9 +238,12 @@ function sample_posterior(rng, prob::AbstractHybridProblem, xM::AbstractMatrix;
238238 int_unc = interpreters. unc
239239 transMs = StackedArray (transM, n_batch)
240240 g_dev, ϕ_dev = gdevs. gdev_M (g), gdevs. gdev_M (ϕ)
241- sample_posterior (rng, g_dev, ϕ_dev, xM;
241+ (; θsP, θsMs, entropy_ζ) = sample_posterior (rng, g_dev, ϕ_dev, xM;
242242 int_μP_ϕg_unc, int_unc, transP, transM,
243243 n_sample_pred, cdev= infer_cdev (gdevs), cor_ends, pbm_covar_indices, kwargs... )
244+ θsPc = ComponentArrayInterpreter (prob. θP, (n_sample_pred,))(θsP)
245+ θsMsc = ComponentArrayInterpreter ((n_site,), prob. θM, (n_sample_pred,))(θsMs)
246+ (; θsP= θsPc, θsMs= θsMsc, entropy_ζ)
244247end
245248
246249function sample_posterior (rng, g, ϕ:: AbstractVector , xM:: AbstractMatrix ;
0 commit comments