Skip to content

Commit d6297ba

Browse files
committed
add compute_log_prior
1 parent 6657169 commit d6297ba

File tree

5 files changed

+285
-81
lines changed

5 files changed

+285
-81
lines changed

pymc/stats/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,6 @@
2828
setattr(sys.modules[__name__], attr, obj)
2929

3030
from pymc.stats.log_likelihood import compute_log_likelihood
31+
from pymc.stats.log_prior import compute_log_prior
3132

32-
__all__ = ("compute_log_likelihood", *az.stats.__all__)
33+
__all__ = ("compute_log_likelihood", "compute_log_prior", *az.stats.__all__)

pymc/stats/log_density.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections.abc import Sequence
15+
from typing import Optional, cast
16+
17+
from arviz import InferenceData, dict_to_dataset
18+
from fastprogress import progress_bar
19+
20+
import pymc
21+
22+
from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata
23+
from pymc.model import Model, modelcontext
24+
from pymc.pytensorf import PointFunc
25+
from pymc.util import dataset_to_point_list
26+
27+
28+
def compute_log_prior(
    idata: InferenceData,
    var_names: Optional[Sequence[str]] = None,
    extend_inferencedata: bool = True,
    model: Optional[Model] = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar=True,
):
    """Compute elemwise log_prior of model given InferenceData with posterior group.

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str, optional
        List of free variable names for which to compute log_prior.
        Defaults to all free variables.
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one
    model : Model, optional
        Model used to generate the posterior. Taken from context if not provided.
    sample_dims : sequence of str, default ("chain", "draw")
    progressbar : bool, default True

    Returns
    -------
    idata : InferenceData
        InferenceData with log_prior group
    """
    # NOTE(review): an identical compute_log_prior is also defined in
    # pymc/stats/log_prior.py in this commit — consider keeping only one.
    # Thin wrapper: all the work happens in compute_log_density.
    return compute_log_density(
        idata=idata,
        var_names=var_names,
        extend_inferencedata=extend_inferencedata,
        model=model,
        kind="prior",
        sample_dims=sample_dims,
        progressbar=progressbar,
    )
65+
66+
67+
def compute_log_density(
    idata: InferenceData,
    *,
    var_names: Optional[Sequence[str]] = None,
    extend_inferencedata: bool = True,
    model: Optional[Model] = None,
    kind="likelihood",
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar=True,
):
    """Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group.

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group.
    var_names : sequence of str, optional
        Names of variables for which to compute the log density. Defaults to
        all observed_RVs (kind="likelihood") or all free_RVs (kind="prior").
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one.
    model : Model, optional
        Model used to generate the posterior. Taken from context if not provided.
    kind : str, default "likelihood"
        Which density to compute: "likelihood" or "prior".
    sample_dims : sequence of str, default ("chain", "draw")
    progressbar : bool, default True

    Returns
    -------
    idata : InferenceData
        InferenceData extended with a log_likelihood or log_prior group, or the
        new dataset alone when extend_inferencedata=False.

    Raises
    ------
    ValueError
        If kind is not "likelihood"/"prior", or var_names refer to variables
        outside the valid target set.
    """
    posterior = idata["posterior"]

    model = modelcontext(model)

    if kind not in ("likelihood", "prior"):
        raise ValueError("kind must be either 'likelihood' or 'prior'")

    if kind == "likelihood":
        target_rvs = model.observed_RVs
        target_str = "observed_RVs"
    else:
        # Fix: use free_RVs, not unobserved_RVs. unobserved_RVs also contains
        # Deterministics, which have no logp — with the default var_names the
        # model.logp call below would fail for any model with Deterministics,
        # and the error message below already promised "free_RVs".
        target_rvs = model.free_RVs
        target_str = "free_RVs"

    if var_names is None:
        rvs = target_rvs
        var_names = tuple(rv.name for rv in rvs)
    else:
        rvs = [model.named_vars[name] for name in var_names]
        if not set(rvs).issubset(target_rvs):
            raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}")

    # We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
    try:
        original_rvs_to_values = model.rvs_to_values
        original_rvs_to_transforms = model.rvs_to_transforms

        model.rvs_to_values = {
            rv: rv.clone() if rv not in model.observed_RVs else value
            for rv, value in model.rvs_to_values.items()
        }
        model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}

        elemwise_logdens_fn = model.compile_fn(
            inputs=model.value_vars,
            outs=model.logp(vars=rvs, sum=False),
            on_unused_input="ignore",
        )
        elemwise_logdens_fn = cast(PointFunc, elemwise_logdens_fn)
    finally:
        # Restore the original mappings even if compilation failed.
        model.rvs_to_values = original_rvs_to_values
        model.rvs_to_transforms = original_rvs_to_transforms

    # Ignore Deterministics
    posterior_values = posterior[[rv.name for rv in model.free_RVs]]
    posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)

    n_pts = len(posterior_pts)
    logdens_dict = _DefaultTrace(n_pts)
    indices = range(n_pts)
    if progressbar:
        indices = progress_bar(indices, total=n_pts, display=progressbar)

    for idx in indices:
        logdenss_pts = elemwise_logdens_fn(posterior_pts[idx])
        for rv_name, rv_logdens in zip(var_names, logdenss_pts):
            logdens_dict.insert(rv_name, rv_logdens, idx)

    logdens_trace = logdens_dict.trace_dict
    for key, array in logdens_trace.items():
        # Unstack the flattened sample dimension back into (chain, draw, ...).
        logdens_trace[key] = array.reshape(
            (*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
        )

    coords, dims = coords_and_dims_for_inferencedata(model)
    logdens_dataset = dict_to_dataset(
        logdens_trace,
        library=pymc,
        dims=dims,
        coords=coords,
        default_dims=list(sample_dims),
        skip_event_dims=True,
    )

    if extend_inferencedata:
        idata.add_groups({f"log_{kind}": logdens_dataset})
        return idata
    else:
        return logdens_dataset

pymc/stats/log_likelihood.py

Lines changed: 15 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from collections.abc import Sequence
15-
from typing import Optional, cast
15+
from typing import Optional
1616

17-
from arviz import InferenceData, dict_to_dataset
18-
from fastprogress import progress_bar
17+
from arviz import InferenceData
1918

20-
import pymc
19+
from pymc.model import Model
20+
from pymc.stats.log_density import compute_log_density
2121

22-
from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata
23-
from pymc.model import Model, modelcontext
24-
from pymc.pytensorf import PointFunc
25-
from pymc.util import dataset_to_point_list
26-
27-
__all__ = ("compute_log_likelihood",)
22+
# Fix: __all__ must be a sequence of names, not a bare string. A string
# __all__ makes `from ... import *` iterate over its characters and fail.
__all__ = ("compute_log_likelihood",)
2823

2924

3025
def compute_log_likelihood(
@@ -43,7 +38,8 @@ def compute_log_likelihood(
4338
idata : InferenceData
4439
InferenceData with posterior group
4540
var_names : sequence of str, optional
46-
List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
41+
List of Observed variable names for which to compute log_likelihood.
42+
Defaults to all observed variables.
4743
extend_inferencedata : bool, default True
4844
Whether to extend the original InferenceData or return a new one
4945
model : Model, optional
@@ -54,74 +50,13 @@ def compute_log_likelihood(
5450
-------
5551
idata : InferenceData
5652
InferenceData with log_likelihood group
57-
5853
"""
59-
60-
posterior = idata["posterior"]
61-
62-
model = modelcontext(model)
63-
64-
if var_names is None:
65-
observed_vars = model.observed_RVs
66-
var_names = tuple(rv.name for rv in observed_vars)
67-
else:
68-
observed_vars = [model.named_vars[name] for name in var_names]
69-
if not set(observed_vars).issubset(model.observed_RVs):
70-
raise ValueError(f"var_names must refer to observed_RVs in the model. Got: {var_names}")
71-
72-
# We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
73-
try:
74-
original_rvs_to_values = model.rvs_to_values
75-
original_rvs_to_transforms = model.rvs_to_transforms
76-
77-
model.rvs_to_values = {
78-
rv: rv.clone() if rv not in model.observed_RVs else value
79-
for rv, value in model.rvs_to_values.items()
80-
}
81-
model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}
82-
83-
elemwise_loglike_fn = model.compile_fn(
84-
inputs=model.value_vars,
85-
outs=model.logp(vars=observed_vars, sum=False),
86-
on_unused_input="ignore",
87-
)
88-
elemwise_loglike_fn = cast(PointFunc, elemwise_loglike_fn)
89-
finally:
90-
model.rvs_to_values = original_rvs_to_values
91-
model.rvs_to_transforms = original_rvs_to_transforms
92-
93-
# Ignore Deterministics
94-
posterior_values = posterior[[rv.name for rv in model.free_RVs]]
95-
posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
96-
n_pts = len(posterior_pts)
97-
loglike_dict = _DefaultTrace(n_pts)
98-
indices = range(n_pts)
99-
if progressbar:
100-
indices = progress_bar(indices, total=n_pts, display=progressbar)
101-
102-
for idx in indices:
103-
loglikes_pts = elemwise_loglike_fn(posterior_pts[idx])
104-
for rv_name, rv_loglike in zip(var_names, loglikes_pts):
105-
loglike_dict.insert(rv_name, rv_loglike, idx)
106-
107-
loglike_trace = loglike_dict.trace_dict
108-
for key, array in loglike_trace.items():
109-
loglike_trace[key] = array.reshape(
110-
(*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
111-
)
112-
113-
coords, dims = coords_and_dims_for_inferencedata(model)
114-
loglike_dataset = dict_to_dataset(
115-
loglike_trace,
116-
library=pymc,
117-
dims=dims,
118-
coords=coords,
119-
default_dims=list(sample_dims),
120-
skip_event_dims=True,
54+
return compute_log_density(
55+
idata=idata,
56+
var_names=var_names,
57+
extend_inferencedata=extend_inferencedata,
58+
model=model,
59+
kind="likelihood",
60+
sample_dims=sample_dims,
61+
progressbar=progressbar,
12162
)
122-
123-
if extend_inferencedata:
124-
idata.add_groups(dict(log_likelihood=loglike_dataset))
125-
return idata
126-
else:
127-
return loglike_dataset

pymc/stats/log_prior.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections.abc import Sequence
15+
from typing import Optional
16+
17+
from arviz import InferenceData
18+
19+
from pymc.model import Model
20+
from pymc.stats.log_density import compute_log_density
21+
22+
# Fix: __all__ must be a sequence of names, not a bare string. A string
# __all__ makes `from ... import *` iterate over its characters and fail.
__all__ = ("compute_log_prior",)
23+
24+
25+
def compute_log_prior(
    idata: InferenceData,
    var_names: Optional[Sequence[str]] = None,
    extend_inferencedata: bool = True,
    model: Optional[Model] = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar=True,
):
    """Compute elemwise log_prior of model given InferenceData with posterior group.

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str, optional
        List of free variable names for which to compute log_prior.
        Defaults to all free variables.
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one
    model : Model, optional
        Model used to generate the posterior. Taken from context if not provided.
    sample_dims : sequence of str, default ("chain", "draw")
    progressbar : bool, default True

    Returns
    -------
    idata : InferenceData
        InferenceData with log_prior group
    """
    # NOTE(review): this duplicates compute_log_prior defined in
    # pymc/stats/log_density.py in the same commit — consider keeping only one.
    # Thin wrapper: all the work happens in compute_log_density.
    return compute_log_density(
        idata=idata,
        var_names=var_names,
        extend_inferencedata=extend_inferencedata,
        model=model,
        kind="prior",
        sample_dims=sample_dims,
        progressbar=progressbar,
    )

tests/stats/test_log_prior.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytest
16+
import scipy.stats as st
17+
18+
from arviz import InferenceData, dict_to_dataset
19+
20+
from pymc.distributions import Normal
21+
from pymc.distributions.transforms import log
22+
from pymc.model import Model
23+
from pymc.stats.log_prior import compute_log_prior
24+
25+
26+
class TestComputeLogPrior:
    # Run once with an explicit log transform on `x` and once without, to
    # check compute_log_prior handles both transformed and untransformed
    # free variables (the posterior stores untransformed values either way).
    @pytest.mark.parametrize("transform", (False, True))
    def test_basic(self, transform):
        transform = log if transform else None
        with Model() as m:
            x = Normal("x", transform=transform)
            # Keep a handle on the original value variable so we can verify
            # the model mappings are restored afterwards.
            x_value_var = m.rvs_to_values[x]
            Normal("y", x, observed=[0, 1, 2])

        # Synthetic posterior: 4 chains x 25 draws of `x`.
        idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
        res = compute_log_prior(idata)

        # Check we didn't erase the original mappings
        assert m.rvs_to_values[x] is x_value_var
        assert m.rvs_to_transforms[x] is transform

        # extend_inferencedata defaults to True, so the same object is returned.
        assert res is idata
        assert res.log_prior.dims == {"chain": 4, "draw": 25}

        # log_prior of x ~ Normal(0, 1) must match scipy's logpdf elemwise.
        np.testing.assert_allclose(
            res.log_prior["x"].values,
            st.norm(0, 1).logpdf(idata.posterior["x"].values),
        )

0 commit comments

Comments
 (0)