Skip to content

Commit 871dfb3

Browse files
pdb5627ricardoV94
authored andcommitted
Pass kwargs through to pymc sample calls
Fixes #184 by allowing `random_seed` to be set in calls to `predict` in `test_save_load` for `LinearModel`.
1 parent 0ef82df commit 871dfb3

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

pymc_experimental/model_builder.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ def predict(
468468
self,
469469
X_pred: Union[np.ndarray, pd.DataFrame, pd.Series],
470470
extend_idata: bool = True,
471+
**kwargs,
471472
) -> np.ndarray:
472473
"""
473474
Uses model to predict on unseen data and return point prediction of all the samples. The point prediction
@@ -479,6 +480,7 @@ def predict(
479480
The input data used for prediction.
480481
extend_idata : Boolean determining whether the predictions should be added to inference data object.
481482
Defaults to True.
483+
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
482484
483485
Returns
484486
-------
@@ -495,7 +497,7 @@ def predict(
495497
"""
496498

497499
posterior_predictive_samples = self.sample_posterior_predictive(
498-
X_pred, extend_idata, combined=False
500+
X_pred, extend_idata, combined=False, **kwargs
499501
)
500502

501503
if self.output_var not in posterior_predictive_samples:
@@ -514,6 +516,7 @@ def sample_prior_predictive(
514516
samples: Optional[int] = None,
515517
extend_idata: bool = False,
516518
combined: bool = True,
519+
**kwargs,
517520
):
518521
"""
519522
Sample from the model's prior predictive distribution.
@@ -529,6 +532,7 @@ def sample_prior_predictive(
529532
Defaults to False.
530533
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
531534
Defaults to True.
535+
**kwargs: Additional arguments to pass to pymc.sample_prior_predictive
532536
533537
Returns
534538
-------
@@ -544,7 +548,7 @@ def sample_prior_predictive(
544548
self._data_setter(X_pred)
545549
if self.model is not None:
546550
with self.model: # sample with new input data
547-
prior_pred: az.InferenceData = pm.sample_prior_predictive(samples)
551+
prior_pred: az.InferenceData = pm.sample_prior_predictive(samples, **kwargs)
548552
self.set_idata_attrs(prior_pred)
549553
if extend_idata:
550554
if self.idata is not None:
@@ -556,7 +560,7 @@ def sample_prior_predictive(
556560

557561
return prior_predictive_samples
558562

559-
def sample_posterior_predictive(self, X_pred, extend_idata, combined):
563+
def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
560564
"""
561565
Sample from the model's posterior predictive distribution.
562566
@@ -568,6 +572,7 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined):
568572
Defaults to False.
569573
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
570574
Defaults to True.
575+
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
571576
572577
Returns
573578
-------
@@ -577,7 +582,7 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined):
577582
self._data_setter(X_pred)
578583

579584
with self.model: # sample with new input data
580-
post_pred = pm.sample_posterior_predictive(self.idata)
585+
post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
581586
if extend_idata:
582587
self.idata.extend(post_pred)
583588

@@ -621,15 +626,17 @@ def predict_proba(
621626
X_pred: Union[np.ndarray, pd.DataFrame, pd.Series],
622627
extend_idata: bool = True,
623628
combined: bool = False,
629+
**kwargs,
624630
) -> xr.DataArray:
625631
"""Alias for `predict_posterior`, for consistency with scikit-learn probabilistic estimators."""
626-
return self.predict_posterior(X_pred, extend_idata, combined)
632+
return self.predict_posterior(X_pred, extend_idata, combined, **kwargs)
627633

628634
def predict_posterior(
629635
self,
630636
X_pred: Union[np.ndarray, pd.DataFrame, pd.Series],
631637
extend_idata: bool = True,
632638
combined: bool = True,
639+
**kwargs,
633640
) -> xr.DataArray:
634641
"""
635642
Generate posterior predictive samples on unseen data.
@@ -642,6 +649,7 @@ def predict_posterior(
642649
Defaults to True.
643650
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
644651
Defaults to True.
652+
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
645653
646654
Returns
647655
-------
@@ -651,7 +659,7 @@ def predict_posterior(
651659

652660
X_pred = self._validate_data(X_pred)
653661
posterior_predictive_samples = self.sample_posterior_predictive(
654-
X_pred, extend_idata, combined
662+
X_pred, extend_idata, combined, **kwargs
655663
)
656664

657665
if self.output_var not in posterior_predictive_samples:

pymc_experimental/tests/test_linearmodel.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,10 @@ def test_save_load(fitted_linear_model_instance):
8282
assert model.idata.groups() == model2.idata.groups()
8383

8484
x_pred = np.random.uniform(low=0, high=1, size=(100, 1))
85-
pred1 = model.predict(x_pred)
86-
pred2 = model2.predict(x_pred)
87-
# Predictions should have similar statistical characteristics
88-
assert pred1.mean() == pytest.approx(pred2.mean(), 1e-3)
89-
assert pred1.var() == pytest.approx(pred2.var(), 1e-2)
85+
pred1 = model.predict(x_pred, random_seed=423)
86+
pred2 = model2.predict(x_pred, random_seed=423)
87+
# Predictions should be identical
88+
np.testing.assert_array_equal(pred1, pred2)
9089
temp.close()
9190

9291

0 commit comments

Comments
 (0)