1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from collections .abc import Sequence
15- from typing import Literal
15+ from typing import Any , Literal
1616
1717from arviz import InferenceData
1818from xarray import Dataset
@@ -36,6 +36,7 @@ def compute_log_likelihood(
3636 model : Model | None = None ,
3737 sample_dims : Sequence [str ] = ("chain" , "draw" ),
3838 progressbar = True ,
39+ compile_kwargs : dict [str , Any ] | None = None ,
3940):
4041 """Compute elemwise log_likelihood of model given InferenceData with posterior group
4142
@@ -51,6 +52,8 @@ def compute_log_likelihood(
5152 model : Model, optional
5253 sample_dims : sequence of str, default ("chain", "draw")
5354 progressbar : bool, default True
55+ compile_kwargs : dict[str, Any] | None
56+ Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density`
5457
5558 Returns
5659 -------
@@ -65,6 +68,7 @@ def compute_log_likelihood(
6568 kind = "likelihood" ,
6669 sample_dims = sample_dims ,
6770 progressbar = progressbar ,
71+ compile_kwargs = compile_kwargs ,
6872 )
6973
7074
@@ -75,6 +79,7 @@ def compute_log_prior(
7579 model : Model | None = None ,
7680 sample_dims : Sequence [str ] = ("chain" , "draw" ),
7781 progressbar = True ,
82+ compile_kwargs = None ,
7883):
7984 """Compute elemwise log_prior of model given InferenceData with posterior group
8085
@@ -90,6 +95,8 @@ def compute_log_prior(
9095 model : Model, optional
9196 sample_dims : sequence of str, default ("chain", "draw")
9297 progressbar : bool, default True
98+ compile_kwargs : dict[str, Any] | None
99+ Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density`
93100
94101 Returns
95102 -------
@@ -104,6 +111,7 @@ def compute_log_prior(
104111 kind = "prior" ,
105112 sample_dims = sample_dims ,
106113 progressbar = progressbar ,
114+ compile_kwargs = compile_kwargs ,
107115 )
108116
109117
@@ -116,14 +124,42 @@ def compute_log_density(
116124 kind : Literal ["likelihood" , "prior" ] = "likelihood" ,
117125 sample_dims : Sequence [str ] = ("chain" , "draw" ),
118126 progressbar = True ,
127+ compile_kwargs = None ,
119128) -> InferenceData | Dataset :
120129 """
121130 Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group
131+
132+ Parameters
133+ ----------
134+ idata : InferenceData
135+ InferenceData with posterior group
136+ var_names : sequence of str, optional
137+ List of Observed variable names for which to compute log_prior.
138+ Defaults to all all free variables.
139+ extend_inferencedata : bool, default True
140+ Whether to extend the original InferenceData or return a new one
141+ model : Model, optional
142+ kind: Literal["likelihood", "prior"]
143+ Whether to compute the log density of the observed random variables (likelihood)
144+ or to compute the log density of the latent random variables (prior). This
145+ parameter determines the group that gets added to the returned `~arviz.InferenceData` object.
146+ sample_dims : sequence of str, default ("chain", "draw")
147+ progressbar : bool, default True
148+ compile_kwargs : dict[str, Any] | None
149+ Extra compilation arguments to supply to :py:func:`pymc.model.core.Model.compile_fn`
150+
151+ Returns
152+ -------
153+ idata : InferenceData
154+ InferenceData with the ``log_likelihood`` group when ``kind == "likelihood"``
155+ or the ``log_prior`` group when ``kind == "prior"``.
122156 """
123157
124158 posterior = idata ["posterior" ]
125159
126160 model = modelcontext (model )
161+ if compile_kwargs is None :
162+ compile_kwargs = {}
127163
128164 if kind not in ("likelihood" , "prior" ):
129165 raise ValueError ("kind must be either 'likelihood' or 'prior'" )
@@ -150,6 +186,7 @@ def compute_log_density(
150186 inputs = umodel .value_vars ,
151187 outs = umodel .logp (vars = vars , sum = False ),
152188 on_unused_input = "ignore" ,
189+ ** compile_kwargs ,
153190 )
154191
155192 coords , dims = coords_and_dims_for_inferencedata (umodel )
0 commit comments