Skip to content

Commit 726d18b

Browse files
committed
unify
1 parent d6297ba commit 726d18b

File tree

8 files changed

+69
-175
lines changed

8 files changed

+69
-175
lines changed

docs/source/api/misc.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ Other utils
77
:toctree: generated/
88

99
compute_log_likelihood
10+
compute_log_prior
1011
find_constrained_prior
1112
DictToArrayBijection

pymc/backends/arviz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def to_inference_data(self):
436436
id_dict["constant_data"] = self.constant_data_to_xarray()
437437
idata = InferenceData(save_warmup=self.save_warmup, **id_dict)
438438
if self.log_likelihood:
439-
from pymc.stats.log_likelihood import compute_log_likelihood
439+
from pymc.stats.log_density import compute_log_likelihood
440440

441441
idata = compute_log_likelihood(
442442
idata,

pymc/stats/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
if not attr.startswith("__"):
2828
setattr(sys.modules[__name__], attr, obj)
2929

30-
from pymc.stats.log_likelihood import compute_log_likelihood
31-
from pymc.stats.log_prior import compute_log_prior
30+
from pymc.stats.log_density import compute_log_likelihood, compute_log_prior
3231

3332
__all__ = ("compute_log_likelihood", "compute_log_prior", *az.stats.__all__)

pymc/stats/log_density.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,48 @@
2424
from pymc.pytensorf import PointFunc
2525
from pymc.util import dataset_to_point_list
2626

27+
__all__ = ("compute_log_likelihood", "compute_log_prior")
28+
29+
30+
def compute_log_likelihood(
31+
idata: InferenceData,
32+
*,
33+
var_names: Optional[Sequence[str]] = None,
34+
extend_inferencedata: bool = True,
35+
model: Optional[Model] = None,
36+
sample_dims: Sequence[str] = ("chain", "draw"),
37+
progressbar=True,
38+
):
39+
"""Compute elemwise log_likelihood of model given InferenceData with posterior group
40+
41+
Parameters
42+
----------
43+
idata : InferenceData
44+
InferenceData with posterior group
45+
var_names : sequence of str, optional
46+
List of Observed variable names for which to compute log_likelihood.
47+
Defaults to all observed variables.
48+
extend_inferencedata : bool, default True
49+
Whether to extend the original InferenceData or return a new one
50+
model : Model, optional
51+
sample_dims : sequence of str, default ("chain", "draw")
52+
progressbar : bool, default True
53+
54+
Returns
55+
-------
56+
idata : InferenceData
57+
InferenceData with log_likelihood group
58+
"""
59+
return compute_log_density(
60+
idata=idata,
61+
var_names=var_names,
62+
extend_inferencedata=extend_inferencedata,
63+
model=model,
64+
kind="likelihood",
65+
sample_dims=sample_dims,
66+
progressbar=progressbar,
67+
)
68+
2769

2870
def compute_log_prior(
2971
idata: InferenceData,

pymc/stats/log_likelihood.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

pymc/stats/log_prior.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

tests/stats/test_log_likelihood.py renamed to tests/stats/test_log_density.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pymc.distributions import Dirichlet, Normal
2121
from pymc.distributions.transforms import log
2222
from pymc.model import Model
23-
from pymc.stats.log_likelihood import compute_log_likelihood
23+
from pymc.stats.log_density import compute_log_likelihood, compute_log_prior
2424
from tests.distributions.test_multivariate import dirichlet_logpdf
2525

2626

@@ -132,3 +132,26 @@ def test_dims_without_coords(self):
132132
llike.log_likelihood["y"].values,
133133
st.norm.logpdf([[[0, 0, 0], [1, 1, 1]]]),
134134
)
135+
136+
@pytest.mark.parametrize("transform", (False, True))
137+
def test_basic_log_prior(self, transform):
138+
transform = log if transform else None
139+
with Model() as m:
140+
x = Normal("x", transform=transform)
141+
x_value_var = m.rvs_to_values[x]
142+
Normal("y", x, observed=[0, 1, 2])
143+
144+
idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
145+
res = compute_log_prior(idata)
146+
147+
# Check we didn't erase the original mappings
148+
assert m.rvs_to_values[x] is x_value_var
149+
assert m.rvs_to_transforms[x] is transform
150+
151+
assert res is idata
152+
assert res.log_prior.dims == {"chain": 4, "draw": 25}
153+
154+
np.testing.assert_allclose(
155+
res.log_prior["x"].values,
156+
st.norm(0, 1).logpdf(idata.posterior["x"].values),
157+
)

tests/stats/test_log_prior.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

0 commit comments

Comments
 (0)