@@ -130,6 +130,11 @@ Prediction function for hybrid model. Returns an NamedTuple with entries
130130- `θ`: ComponentArray `(n_θP + n_site * n_θM), n_sample_pred)` of PBM model parameters.
131131- `y`: Array `(n_obs, n_site, n_sample_pred)` of model predictions.
132132"""
133+ function predict_gf (rng, prob:: AbstractHybridProblem ; scenario, kwargs... )
134+ n_batch = get_hybridproblem_n_site (prob; scenario)
135+ data = first (get_hybridproblem_train_dataloader (prob; scenario, n_batch))
136+ predict_gf (rng, prob, data[1 ], data[2 ]; scenario, kwargs... )
137+ end
133138function predict_gf (rng, prob:: AbstractHybridProblem , xM:: AbstractMatrix , xP;
134139 scenario,
135140 n_sample_pred = 200 ,
@@ -146,11 +151,12 @@ function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP;
146151 (; transP, transM) = get_hybridproblem_transforms (prob; scenario)
147152 f = get_hybridproblem_PBmodel (prob; scenario)
148153 pbm_covars = get_hybridproblem_pbmpar_covars (prob; scenario)
154+ pbm_covar_indices = get_pbm_covar_indices (θP, pbm_covars)
149155 (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params (
150156 θP, θM, cor_ends, ϕg0, n_site; transP, transM, ϕunc0)
151157 g_dev, ϕ_dev = gdev (g), gdev (ϕ)
152158 predict_gf (rng, g_dev, f, ϕ_dev, xM, xP, interpreters;
153- get_transPMs, get_ca_int_PMs, n_sample_pred, cdev, cor_ends, pbm_covars )
159+ get_transPMs, get_ca_int_PMs, n_sample_pred, cdev, cor_ends, pbm_covar_indices )
154160end
155161
156162function predict_gf (rng, g, f, ϕ:: AbstractVector , xM:: AbstractMatrix , xP, interpreters;
0 commit comments