@@ -468,6 +468,7 @@ def predict(
468
468
self ,
469
469
X_pred : Union [np .ndarray , pd .DataFrame , pd .Series ],
470
470
extend_idata : bool = True ,
471
+ ** kwargs ,
471
472
) -> np .ndarray :
472
473
"""
473
474
Uses model to predict on unseen data and return point prediction of all the samples. The point prediction
@@ -479,6 +480,7 @@ def predict(
479
480
The input data used for prediction.
480
481
extend_idata : Boolean determining whether the predictions should be added to inference data object.
481
482
Defaults to True.
483
+ **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
482
484
483
485
Returns
484
486
-------
@@ -495,7 +497,7 @@ def predict(
495
497
"""
496
498
497
499
posterior_predictive_samples = self .sample_posterior_predictive (
498
- X_pred , extend_idata , combined = False
500
+ X_pred , extend_idata , combined = False , ** kwargs
499
501
)
500
502
501
503
if self .output_var not in posterior_predictive_samples :
@@ -514,6 +516,7 @@ def sample_prior_predictive(
514
516
samples : Optional [int ] = None ,
515
517
extend_idata : bool = False ,
516
518
combined : bool = True ,
519
+ ** kwargs ,
517
520
):
518
521
"""
519
522
Sample from the model's prior predictive distribution.
@@ -529,6 +532,7 @@ def sample_prior_predictive(
529
532
Defaults to False.
530
533
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
531
534
Defaults to True.
535
+ **kwargs: Additional arguments to pass to pymc.sample_prior_predictive
532
536
533
537
Returns
534
538
-------
@@ -544,7 +548,7 @@ def sample_prior_predictive(
544
548
self ._data_setter (X_pred )
545
549
if self .model is not None :
546
550
with self .model : # sample with new input data
547
- prior_pred : az .InferenceData = pm .sample_prior_predictive (samples )
551
+ prior_pred : az .InferenceData = pm .sample_prior_predictive (samples , ** kwargs )
548
552
self .set_idata_attrs (prior_pred )
549
553
if extend_idata :
550
554
if self .idata is not None :
@@ -556,7 +560,7 @@ def sample_prior_predictive(
556
560
557
561
return prior_predictive_samples
558
562
559
- def sample_posterior_predictive (self , X_pred , extend_idata , combined ):
563
+ def sample_posterior_predictive (self , X_pred , extend_idata , combined , ** kwargs ):
560
564
"""
561
565
Sample from the model's posterior predictive distribution.
562
566
@@ -568,6 +572,7 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined):
568
572
Defaults to False.
569
573
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
570
574
Defaults to True.
575
+ **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
571
576
572
577
Returns
573
578
-------
@@ -577,7 +582,7 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined):
577
582
self ._data_setter (X_pred )
578
583
579
584
with self .model : # sample with new input data
580
- post_pred = pm .sample_posterior_predictive (self .idata )
585
+ post_pred = pm .sample_posterior_predictive (self .idata , ** kwargs )
581
586
if extend_idata :
582
587
self .idata .extend (post_pred )
583
588
@@ -621,15 +626,17 @@ def predict_proba(
621
626
X_pred : Union [np .ndarray , pd .DataFrame , pd .Series ],
622
627
extend_idata : bool = True ,
623
628
combined : bool = False ,
629
+ ** kwargs ,
624
630
) -> xr .DataArray :
625
631
"""Alias for `predict_posterior`, for consistency with scikit-learn probabilistic estimators."""
626
- return self .predict_posterior (X_pred , extend_idata , combined )
632
+ return self .predict_posterior (X_pred , extend_idata , combined , ** kwargs )
627
633
628
634
def predict_posterior (
629
635
self ,
630
636
X_pred : Union [np .ndarray , pd .DataFrame , pd .Series ],
631
637
extend_idata : bool = True ,
632
638
combined : bool = True ,
639
+ ** kwargs ,
633
640
) -> xr .DataArray :
634
641
"""
635
642
Generate posterior predictive samples on unseen data.
@@ -642,6 +649,7 @@ def predict_posterior(
642
649
Defaults to True.
643
650
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
644
651
Defaults to True.
652
+ **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
645
653
646
654
Returns
647
655
-------
@@ -651,7 +659,7 @@ def predict_posterior(
651
659
652
660
X_pred = self ._validate_data (X_pred )
653
661
posterior_predictive_samples = self .sample_posterior_predictive (
654
- X_pred , extend_idata , combined
662
+ X_pred , extend_idata , combined , ** kwargs
655
663
)
656
664
657
665
if self .output_var not in posterior_predictive_samples :
0 commit comments