Skip to content

Commit a7959b9

Browse files
committed
attach CA axis to result of sample_posterior
1 parent 6546a1a commit a7959b9

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/elbo.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,12 @@ for hybrid variational inference problem.
186186
- `rng`: random number generator
187187
- `prob`: The AbstractHybridProblem from to sample
188188
- `xM`: covariates for the machine-learning model (ML): Matrix `(n_θM x n_site_pred)`.
189-
Default to all sites in train_dataloader in prob.
189+
Default to all sites in `get_hybridproblem_train_dataloader(prob; scenario)`.
190190
191191
Optional keyword arguments
192192
- `scenario`: scenario to query `prob` and set default of gpu devices.
193193
- `n_sample_pred`: number of samples to draw, defaults to 200
194-
- `gdevs`: NamedTuple(gdev_M, gdev_P): GPU devices for machine learning model
194+
- `gdevs`: `NamedTuple(gdev_M, gdev_P)`: GPU devices for machine learning model
195195
and parameter transformtation, default to [`get_gdev_MP`](@ref)`(scenario)`.
196196
- `is_inferred`: set to `Val(true)` to activate type stabilicy check for transformation
197197
@@ -238,9 +238,12 @@ function sample_posterior(rng, prob::AbstractHybridProblem, xM::AbstractMatrix;
238238
int_unc = interpreters.unc
239239
transMs = StackedArray(transM, n_batch)
240240
g_dev, ϕ_dev = gdevs.gdev_M(g), gdevs.gdev_M(ϕ)
241-
sample_posterior(rng, g_dev, ϕ_dev, xM;
241+
(; θsP, θsMs, entropy_ζ) = sample_posterior(rng, g_dev, ϕ_dev, xM;
242242
int_μP_ϕg_unc, int_unc, transP, transM,
243243
n_sample_pred, cdev=infer_cdev(gdevs), cor_ends, pbm_covar_indices, kwargs...)
244+
θsPc = ComponentArrayInterpreter(prob.θP, (n_sample_pred,))(θsP)
245+
θsMsc = ComponentArrayInterpreter((n_site,), prob.θM, (n_sample_pred,))(θsMs)
246+
(; θsP=θsPc, θsMs=θsMsc, entropy_ζ)
244247
end
245248

246249
function sample_posterior(rng, g, ϕ::AbstractVector, xM::AbstractMatrix;

0 commit comments

Comments
 (0)