Skip to content

Commit ae5a943

Browse files
committed
add log_prior to arviz converter
1 parent 1eef3d2 commit ae5a943

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

pymc/backends/arviz.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def __init__(
174174
prior=None,
175175
posterior_predictive=None,
176176
log_likelihood=False,
177+
log_prior=False,
177178
predictions=None,
178179
coords: Optional[CoordSpec] = None,
179180
dims: Optional[DimSpec] = None,
@@ -215,6 +216,7 @@ def __init__(
215216
self.prior = prior
216217
self.posterior_predictive = posterior_predictive
217218
self.log_likelihood = log_likelihood
219+
self.log_prior = log_prior
218220
self.predictions = predictions
219221

220222
if all(elem is None for elem in (trace, predictions, posterior_predictive, prior)):
@@ -446,6 +448,17 @@ def to_inference_data(self):
446448
sample_dims=self.sample_dims,
447449
progressbar=False,
448450
)
451+
if self.log_prior:
452+
from pymc.stats.log_density import compute_log_prior
453+
454+
idata = compute_log_prior(
455+
idata,
456+
var_names=None if self.log_prior is True else self.log_prior,
457+
extend_inferencedata=True,
458+
model=self.model,
459+
sample_dims=self.sample_dims,
460+
progressbar=False,
461+
)
449462
return idata
450463

451464

@@ -455,6 +468,7 @@ def to_inference_data(
455468
prior: Optional[Mapping[str, Any]] = None,
456469
posterior_predictive: Optional[Mapping[str, Any]] = None,
457470
log_likelihood: Union[bool, Iterable[str]] = False,
471+
log_prior: Union[bool, Iterable[str]] = False,
458472
coords: Optional[CoordSpec] = None,
459473
dims: Optional[DimSpec] = None,
460474
sample_dims: Optional[list] = None,
@@ -481,8 +495,11 @@ def to_inference_data(
481495
Dictionary with the variable names as keys, and values numpy arrays
482496
containing posterior predictive samples.
483497
log_likelihood : bool or array_like of str, optional
484-
List of variables to calculate `log_likelihood`. Defaults to True which calculates
485-
`log_likelihood` for all observed variables. If set to False, log_likelihood is skipped.
498+
List of variables to calculate `log_likelihood`. Defaults to False.
499+
If set to True, computes `log_likelihood` for all observed variables.
500+
log_prior : bool or array_like of str, optional
501+
List of variables to calculate `log_prior`. Defaults to False.
502+
If set to True, computes `log_prior` for all unobserved variables.
486503
coords : dict of {str: array-like}, optional
487504
Map of coordinate names to coordinate values
488505
dims : dict of {str: list of str}, optional
@@ -509,6 +526,7 @@ def to_inference_data(
509526
prior=prior,
510527
posterior_predictive=posterior_predictive,
511528
log_likelihood=log_likelihood,
529+
log_prior=log_prior,
512530
coords=coords,
513531
dims=dims,
514532
sample_dims=sample_dims,

0 commit comments

Comments
 (0)