11from functools import reduce
22from itertools import product
3- from typing import Literal
3+ from typing import Any , Literal
44
55import arviz as az
66import numpy as np
@@ -35,6 +35,29 @@ def make_unpacked_variable_names(name, model: pm.Model) -> list[str]:
3535 return [f"{ name } [{ ',' .join (map (str , label ))} ]" for label in labels ]
3636
3737
38+ def map_results_to_inferece_data (results : dict [str , Any ], model : pm .Model | None = None ):
39+ """
40+ Convert a dictionary of results to an InferenceData object.
41+
42+ Parameters
43+ ----------
44+ results: dict
45+ A dictionary containing the results to convert.
46+ model: Model, optional
47+ A PyMC model. If None, the model is taken from the current model context.
48+
49+ Returns
50+ -------
51+ idata: az.InferenceData
52+ An InferenceData object containing the results.
53+ """
54+ model = pm .modelcontext (model )
55+ coords , dims = coords_and_dims_for_inferencedata (model )
56+
57+ idata = az .convert_to_inference_data (results , coords = coords , dims = dims )
58+ return idata
59+
60+
3861def laplace_draws_to_inferencedata (
3962 posterior_draws : list [np .ndarray [float | int ]], model : pm .Model | None = None
4063) -> az .InferenceData :
@@ -91,13 +114,67 @@ def make_rv_dims(name):
91114 return idata
92115
93116
94- def add_fit_to_inferencedata (
117+ def add_map_posterior_to_inference_data (
118+ idata : az .InferenceData ,
119+ map_point : dict [str , float | int | np .ndarray ],
120+ model : pm .Model | None = None ,
121+ ):
122+ """
123+ Add the MAP point to an InferenceData object in the posterior group.
124+
125+ Unlike a typical posterior, the MAP point is a single point estimate rather than a distribution. As a result, it
126+ does not have a chain or draw dimension, and is stored as a single point in the posterior group.
127+
128+ Parameters
129+ ----------
130+ idata: az.InferenceData
131+ An InferenceData object to which the MAP point will be added.
132+ map_point: dict
133+ A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and
134+ the values should be the corresponding MAP estimates.
135+ model: Model, optional
136+ A PyMC model. If None, the model is taken from the current model context.
137+
138+ Returns
139+ -------
140+ idata: az.InferenceData
141+ The provided InferenceData, with the MAP point added to the posterior group.
142+ """
143+
144+ model = pm .modelcontext (model ) if model is None else model
145+ coords , dims = coords_and_dims_for_inferencedata (model )
146+
147+ # The MAP point will have both the transformed and untransformed variables, so we need to ensure that
148+ # we have the correct dimensions for each variable.
149+ var_name_to_value_name = {rv .name : value .name for rv , value in model .rvs_to_values .items ()}
150+ dims .update (
151+ {
152+ value_name : dims [var_name ]
153+ for var_name , value_name in var_name_to_value_name .items ()
154+ if var_name in dims
155+ }
156+ )
157+
158+ posterior_data = {
159+ name : xr .DataArray (
160+ data = np .asarray (value ),
161+ coords = {dim : coords [dim ] for dim in dims .get (name , [])},
162+ dims = dims .get (name ),
163+ name = name ,
164+ )
165+ for name , value in map_point .items ()
166+ }
167+ idata .add_groups (posterior = xr .Dataset (posterior_data ))
168+
169+ return idata
170+
171+
172+ def add_fit_to_inference_data (
95173 idata : az .InferenceData , mu : RaveledVars , H_inv : np .ndarray , model : pm .Model | None = None
96174) -> az .InferenceData :
97175 """
98176 Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
99177
100-
101178 Parameters
102179 ----------
103180 idata: az.InfereceData
@@ -123,19 +200,24 @@ def add_fit_to_inferencedata(
123200 )
124201
125202 mean_dataarray = xr .DataArray (mu .data , dims = ["rows" ], coords = {"rows" : unpacked_variable_names })
126- cov_dataarray = xr .DataArray (
127- H_inv ,
128- dims = ["rows" , "columns" ],
129- coords = {"rows" : unpacked_variable_names , "columns" : unpacked_variable_names },
130- )
131203
132- dataset = xr .Dataset ({"mean_vector" : mean_dataarray , "covariance_matrix" : cov_dataarray })
204+ data = {"mean_vector" : mean_dataarray }
205+
206+ if H_inv is not None :
207+ cov_dataarray = xr .DataArray (
208+ H_inv ,
209+ dims = ["rows" , "columns" ],
210+ coords = {"rows" : unpacked_variable_names , "columns" : unpacked_variable_names },
211+ )
212+ data ["covariance_matrix" ] = cov_dataarray
213+
214+ dataset = xr .Dataset (data )
133215 idata .add_groups (fit = dataset )
134216
135217 return idata
136218
137219
138- def add_data_to_inferencedata (
220+ def add_data_to_inference_data (
139221 idata : az .InferenceData ,
140222 progressbar : bool = True ,
141223 model : pm .Model | None = None ,
@@ -163,8 +245,14 @@ def add_data_to_inferencedata(
163245 model = pm .modelcontext (model )
164246
165247 if model .deterministics :
248+ expand_dims = {}
249+ if "chain" not in idata .posterior .coords :
250+ expand_dims ["chain" ] = [0 ]
251+ if "draw" not in idata .posterior .coords :
252+ expand_dims ["draw" ] = [0 ]
253+
166254 idata .posterior = pm .compute_deterministics (
167- idata .posterior ,
255+ idata .posterior . expand_dims ( expand_dims ) ,
168256 model = model ,
169257 merge_dataset = True ,
170258 progressbar = progressbar ,
@@ -299,3 +387,37 @@ def optimizer_result_to_dataset(
299387 data_vars ["method" ] = xr .DataArray (np .array (method ), dims = [])
300388
301389 return xr .Dataset (data_vars )
390+
391+
392+ def add_optimizer_result_to_inference_data (
393+ idata : az .InferenceData ,
394+ result : OptimizeResult ,
395+ method : minimize_method | Literal ["basinhopping" ],
396+ mu : RaveledVars | None = None ,
397+ model : pm .Model | None = None ,
398+ ) -> az .InferenceData :
399+ """
400+ Add the optimization result to an InferenceData object.
401+
402+ Parameters
403+ ----------
404+ idata: az.InferenceData
405+ An InferenceData object containing the approximated posterior samples.
406+ result: OptimizeResult
407+ The result of the optimization process.
408+ method: minimize_method or "basinhopping"
409+ The optimization method used.
410+ mu: RaveledVars, optional
411+ The MAP estimate of the model parameters.
412+ model: Model, optional
413+ A PyMC model. If None, the model is taken from the current model context.
414+
415+ Returns
416+ -------
417+ idata: az.InferenceData
418+ The provided InferenceData, with the optimization results added to the "optimizer" group.
419+ """
420+ dataset = optimizer_result_to_dataset (result , method = method , mu = mu , model = model )
421+ idata .add_groups ({"optimizer_result" : dataset })
422+
423+ return idata
0 commit comments