Skip to content

Commit 34839ad

Browse files
committed
Fix group selection for posterior predictive samples when predictions = True
1 parent 00a4ca3 commit 34839ad

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

pymc_extras/model_builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,8 +650,11 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
650650
if extend_idata:
651651
self.idata.extend(post_pred, join="right")
652652

653+
# Determine the correct group dynamically
654+
group_name = "predictions" if kwargs.get("predictions", False) else "posterior_predictive"
655+
653656
posterior_predictive_samples = az.extract(
654-
post_pred, "posterior_predictive", combined=combined
657+
post_pred, group_name, combined=combined
655658
)
656659

657660
return posterior_predictive_samples

0 commit comments

Comments
 (0)