Skip to content

Commit f09e13c

Browse files
committed
default predict_gf for all sites of dataloader of problem
1 parent 0f03d3a commit f09e13c

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/elbo.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ Prediction function for hybrid model. Returns an NamedTuple with entries
130130
- `θ`: ComponentArray `(n_θP + n_site * n_θM), n_sample_pred)` of PBM model parameters.
131131
- `y`: Array `(n_obs, n_site, n_sample_pred)` of model predictions.
132132
"""
133+
function predict_gf(rng, prob::AbstractHybridProblem; scenario, kwargs...)
134+
n_batch = get_hybridproblem_n_site(prob; scenario)
135+
data = first(get_hybridproblem_train_dataloader(prob; scenario, n_batch))
136+
predict_gf(rng, prob, data[1], data[2]; scenario, kwargs...)
137+
end
133138
function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP;
134139
scenario,
135140
n_sample_pred = 200,
@@ -146,11 +151,12 @@ function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP;
146151
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
147152
f = get_hybridproblem_PBmodel(prob; scenario)
148153
pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
154+
pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars)
149155
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
150156
θP, θM, cor_ends, ϕg0, n_site; transP, transM, ϕunc0)
151157
g_dev, ϕ_dev = gdev(g), gdev(ϕ)
152158
predict_gf(rng, g_dev, f, ϕ_dev, xM, xP, interpreters;
153-
get_transPMs, get_ca_int_PMs, n_sample_pred, cdev, cor_ends, pbm_covars)
159+
get_transPMs, get_ca_int_PMs, n_sample_pred, cdev, cor_ends, pbm_covar_indices)
154160
end
155161

156162
function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP, interpreters;

0 commit comments

Comments
 (0)