from pymc.initial_point import make_initial_point_fn
from pymc.model.transform.conditioning import remove_value_transforms
from pymc.model.transform.optimization import freeze_dims_and_data
+ from pymc.pytensorf import join_nonshared_inputs
from pymc.sampling.jax import get_jaxified_graph
from pymc.util import get_default_varnames
from pytensor.tensor import TensorVariable

_log = logging.getLogger(__name__)

- def get_near_psd(A: np.ndarray) -> np.ndarray:
+ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
    """
    Compute the nearest positive semi-definite matrix to a given matrix.

-     This function takes a square matrix and returns the nearest positive
-     semi-definite matrix using eigenvalue decomposition. It ensures all
-     eigenvalues are non-negative. The "nearest" matrix is defined in terms
+     This function takes a square matrix and returns the nearest positive semi-definite matrix using
+     eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms
    of the Frobenius norm.

    Parameters
@@ -58,23 +58,13 @@ def get_near_psd(A: np.ndarray) -> np.ndarray:
    return eigvec @ np.diag(eigval) @ eigvec.T


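Only the last line of `get_nearest_psd` is visible in this hunk; for reviewers, here is a self-contained sketch of the eigenvalue-clipping construction the docstring describes (a reconstruction, not necessarily the PR's exact body):

```python
import numpy as np

def nearest_psd_sketch(A: np.ndarray) -> np.ndarray:
    # Symmetrize, then clip negative eigenvalues to zero; this is the
    # classic Frobenius-norm projection onto the PSD cone.
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)
    return eigvec @ np.diag(np.clip(eigval, 0, None)) @ eigvec.T
```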
- def _get_unravel_rv_info(optimized_point, variables, model):
-     cursor = 0
-     slices = {}
-     out_shapes = {}
-
-     for i, var in enumerate(variables):
-         raveled_shape = np.prod(optimized_point[var.name].shape).astype(int)
-         rv = model.values_to_rvs.get(var, var)
-
-         idx = slice(cursor, cursor + raveled_shape)
-         slices[rv] = idx
-         out_shapes[rv] = tuple(
-             [len(model.coords[dim]) for dim in model.named_vars_to_dims.get(rv.name, [])]
-         )
-         cursor += raveled_shape
+ def _unconstrained_vector_to_constrained_rvs(model):
+     constrained_rvs, unconstrained_vector = join_nonshared_inputs(
+         model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars
+     )

-     return slices, out_shapes
+     unconstrained_vector.name = "unconstrained_vector"
+     return constrained_rvs, unconstrained_vector


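Since `join_nonshared_inputs` now does all the raveling, a toy sketch of what it returns may help reviewers (hypothetical two-variable model, standard PyMC API):

```python
import pymc as pm
from pymc.pytensorf import join_nonshared_inputs

with pm.Model() as m:
    x = pm.Normal("x")            # 1 element when raveled
    y = pm.Normal("y", shape=3)   # 3 elements when raveled

# The rewritten outputs take a single flat vector of length 4
# instead of one input per value variable.
outputs, flat_input = join_nonshared_inputs(
    m.initial_point(), inputs=m.value_vars, outputs=m.unobserved_value_vars
)
```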
def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws):
@@ -94,37 +84,24 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
    return f_untransform(posterior_draws)

- def fit_laplace(
+ def jax_fit_mvn_to_MAP(
      optimized_point: dict[str, np.ndarray],
      model: pm.Model,
-     chains: int = 2,
-     draws: int = 500,
      on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
      transform_samples: bool = True,
      zero_tol: float = 1e-8,
      diag_jitter: float | None = 1e-8,
-     progressbar: bool = True,
-     mode: str = "JAX",
- ) -> az.InferenceData:
+ ) -> tuple[RaveledVars, np.ndarray]:
      """
-     Compute the Laplace approximation of the posterior distribution.
-
-     The posterior distribution will be approximated as a Gaussian
-     distribution centered at the posterior mode.
-     The covariance is the inverse of the negative Hessian matrix of
-     the log-posterior evaluated at the mode.
+     Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior
+     evaluated at the MAP estimate. This is the basis of the Laplace approximation.

      Parameters
      ----------
      optimized_point : dict[str, np.ndarray]
-         Local maximum a posteriori (MAP) point returned from pymc.find_MAP
-         or jax_tools.fit_map
+         Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
      model : Model
          A PyMC model
-     chains : int
-         The number of sampling chains running in parallel. Default is 2.
-     draws : int
-         The number of samples to draw from the approximated posterior. Default is 500.
      on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
          What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
          If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in Frobenius norm) will be returned.
@@ -137,18 +114,17 @@ def fit_laplace(
      diag_jitter: float | None
          A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
          If None, no jitter is added. Default is 1e-8.
-     progressbar : bool
-         Whether or not to display progress bar. Default is True.
-     mode : str
-         Computation backend mode. Default is "JAX".

      Returns
      -------
-     InferenceData
-         arviz.InferenceData object storing posterior, observed_data, and constant_data groups.
+     map_estimate: RaveledVars
+         The MAP estimate of the model parameters, raveled into a 1D array.

+     inverse_hessian: np.ndarray
+         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
      """
    frozen_model = freeze_dims_and_data(model)
+
    if not transform_samples:
        untransformed_model = remove_value_transforms(frozen_model)
        logp = untransformed_model.logp(jacobian=False)
@@ -157,19 +133,17 @@
        logp = frozen_model.logp(jacobian=True)
        variables = frozen_model.continuous_value_vars

-     mu = np.concatenate(
-         [np.atleast_1d(optimized_point[var.name]).ravel() for var in variables], axis=0
+     mu = DictToArrayBijection.map(optimized_point)
+
+     [neg_logp], flat_inputs = join_nonshared_inputs(
+         point=frozen_model.initial_point(), outputs=[-logp], inputs=variables
      )

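`DictToArrayBijection.map` replaces the manual `np.concatenate` raveling; a round-trip sketch of `map`/`rmap` with a toy point dict (the `sigma_log__` name is just illustrative of a transformed value variable):

```python
import numpy as np
from pymc.blocking import DictToArrayBijection

point = {"mu": np.array(0.5), "sigma_log__": np.array([0.1, 0.2])}
raveled = DictToArrayBijection.map(point)
assert raveled.data.shape == (3,)            # flat 1D vector
assert raveled.point_map_info[0][0] == "mu"  # (name, shape, dtype) triples
roundtrip = DictToArrayBijection.rmap(raveled)
assert np.allclose(roundtrip["sigma_log__"], [0.1, 0.2])
```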
    f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph(
-         cast(TensorVariable, logp),
-         use_grad=True,
-         use_hess=True,
-         use_hessp=False,
-         inputs=variables,
+         neg_logp, use_grad=True, use_hess=True, use_hessp=False, inputs=[flat_inputs]
    )

-     H = f_hess(mu)
+     H = -f_hess(mu.data)
    H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))

    def stabilize(x, jitter):
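The body of `stabilize` falls outside this hunk; judging from the `diag_jitter` docstring it presumably adds jitter to the diagonal, roughly:

```python
def stabilize(x, jitter):
    # Presumed body (not shown in this hunk): add diag_jitter to the
    # diagonal so borderline-indefinite matrices pass the PSD check.
    return x + np.eye(x.shape[0]) * jitter
```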
@@ -184,73 +158,111 @@ def stabilize(x, jitter):
            raise np.linalg.LinAlgError(
                "Inverse Hessian not positive semi-definite at the provided point"
            )
-         H_inv = get_near_psd(H_inv)
+         H_inv = get_nearest_psd(H_inv)
        if on_bad_cov == "warn":
            _log.warning(
                "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
                "matrix in Frobenius norm instead"
            )

-     posterior_dist = stats.multivariate_normal(mean=mu, cov=H_inv, allow_singular=True)
+     return mu, H_inv
+
+
+ def jax_laplace(
+     mu: RaveledVars,
+     H_inv: np.ndarray,
+     model: pm.Model,
+     chains: int = 2,
+     draws: int = 500,
+     transform_samples: bool = True,
+     progressbar: bool = True,
+ ) -> az.InferenceData:
+     """
+     Draw samples from the multivariate normal (Laplace) approximation to the posterior.
+
+     Parameters
+     ----------
+     mu : RaveledVars
+         The MAP estimate of the model parameters, raveled into a 1D array.
+     H_inv : np.ndarray
+         The inverse Hessian of the log-posterior at the MAP estimate, used as the covariance of the approximation.
+     model : Model
+         A PyMC model
+     chains : int
+         The number of sampling chains running in parallel. Default is 2.
+     draws : int
+         The number of samples to draw from the approximated posterior. Default is 500.
+     transform_samples : bool
+         Whether to transform the samples back to the original parameter space. Default is True.
+     progressbar : bool
+         Whether or not to display progress bar. Default is True.
+
+     Returns
+     -------
+     idata: az.InferenceData
+         An InferenceData object containing the approximated posterior samples.
+     """
+     posterior_dist = stats.multivariate_normal(mean=mu.data, cov=H_inv, allow_singular=True)
    posterior_draws = posterior_dist.rvs(size=(chains, draws))
-     slices, out_shapes = _get_unravel_rv_info(optimized_point, variables, frozen_model)

    if transform_samples:
-         posterior_draws = _create_transformed_draws(
-             H_inv, slices, out_shapes, posterior_draws, frozen_model, chains, draws
-         )
+         constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
+         f_constrain = get_jaxified_graph(inputs=[unconstrained_vector], outputs=constrained_rvs)
+
+         posterior_draws = jax.jit(jax.vmap(jax.vmap(f_constrain)))(posterior_draws)
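The nested `vmap` maps `f_constrain` over the two leading batch axes (chains, then draws); the pattern in isolation, with a toy stand-in function and hypothetical shapes:

```python
import jax
import jax.numpy as jnp

f = lambda x: x ** 2                # stand-in for f_constrain
batch = jnp.ones((2, 500, 10))      # (chains, draws, flat parameter dim)
out = jax.jit(jax.vmap(jax.vmap(f)))(batch)
assert out.shape == (2, 500, 10)
```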
+
    else:
+         info = mu.point_map_info
+         flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
+         slices = [
+             slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
+         ]
+
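The slice bookkeeping just partitions the flat vector by each variable's raveled size; for example, with toy `point_map_info` entries of shape `()` and `(3,)`:

```python
import numpy as np

info = [("x", (), np.float64), ("y", (3,), np.float64)]  # toy point_map_info
flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
slices = [
    slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1]))
    for i in range(len(flat_shapes))
]
assert slices == [slice(0, 1), slice(1, 4)]
```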
        posterior_draws = [
-             posterior_draws[..., idx].reshape((chains, draws, *out_shapes.get(rv, ())))
-             for rv, idx in slices.items()
+             posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
+             for idx, (name, shape, dtype) in zip(slices, info)
        ]

-     def make_rv_coords(rv):
+     def make_rv_coords(name):
          coords = {"chain": range(chains), "draw": range(draws)}
-         extra_dims = frozen_model.named_vars_to_dims.get(rv.name)
+         extra_dims = model.named_vars_to_dims.get(name)
          if extra_dims is None:
              return coords
-         return coords | {dim: list(frozen_model.coords[dim]) for dim in extra_dims}
+         return coords | {dim: list(model.coords[dim]) for dim in extra_dims}

-     def make_rv_dims(rv):
+     def make_rv_dims(name):
          dims = ["chain", "draw"]
-         extra_dims = frozen_model.named_vars_to_dims.get(rv.name)
+         extra_dims = model.named_vars_to_dims.get(name)
          if extra_dims is None:
              return dims
          return dims + list(extra_dims)

      idata = {
-         rv.name: xr.DataArray(
+         name: xr.DataArray(
              data=draws.squeeze(),
-             coords=make_rv_coords(rv),
-             dims=make_rv_dims(rv),
-             name=rv.name,
+             coords=make_rv_coords(name),
+             dims=make_rv_dims(name),
+             name=name,
          )
-         for rv, draws in zip(slices.keys(), posterior_draws)
+         for (name, _, _), draws in zip(mu.point_map_info, posterior_draws)
      }

-     coords, dims = coords_and_dims_for_inferencedata(frozen_model)
+     coords, dims = coords_and_dims_for_inferencedata(model)
      idata = az.convert_to_inference_data(idata, coords=coords, dims=dims)

-     if frozen_model.deterministics:
+     if model.deterministics:
          idata.posterior = pm.compute_deterministics(
              idata.posterior,
-             model=frozen_model,
+             model=model,
              merge_dataset=True,
              progressbar=progressbar,
-             compile_kwargs={"mode": mode},
          )

      observed_data = dict_to_dataset(
-         find_observations(frozen_model),
+         find_observations(model),
          library=pm,
          coords=coords,
          dims=dims,
          default_dims=[],
      )

      constant_data = dict_to_dataset(
-         find_constants(frozen_model),
+         find_constants(model),
          library=pm,
          coords=coords,
          dims=dims,
@@ -266,6 +278,29 @@ def make_rv_dims(rv):
    return idata

+ def fit_laplace(
+     optimized_point: dict[str, np.ndarray],
+     model: pm.Model,
+     chains: int = 2,
+     draws: int = 500,
+     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
+     transform_samples: bool = True,
+     zero_tol: float = 1e-8,
+     diag_jitter: float | None = 1e-8,
+     progressbar: bool = True,
+ ) -> az.InferenceData:
+     mu, H_inv = jax_fit_mvn_to_MAP(
+         optimized_point,
+         model,
+         on_bad_cov,
+         transform_samples,
+         zero_tol,
+         diag_jitter,
+     )
+
+     return jax_laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar)


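End to end, `fit_laplace` stays a thin wrapper over the two new functions. A usage sketch, assuming this module's `find_MAP` accepts the model and returns the optimized point dict that `fit_laplace` expects:

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu")
    pm.Normal("obs", mu=mu, sigma=1, observed=np.array([1.0, 2.0, 3.0]))

best = find_MAP(model=model)              # optimized point as a dict
idata = fit_laplace(best, model, chains=2, draws=500)
```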
def make_jax_funcs_from_graph(
    graph: TensorVariable,
    use_grad: bool,
@@ -280,34 +315,19 @@ def make_jax_funcs_from_graph(
    if not isinstance(inputs, list):
        inputs = [inputs]

-     f = cast(Callable, get_jaxified_graph(inputs=inputs, outputs=[graph]))
-     input_shapes = [x.type.shape for x in inputs]
-
-     def at_least_tuple(x):
-         if isinstance(x, tuple | list):
-             return x
-         return (x,)
+     f_tuple = cast(Callable, get_jaxified_graph(inputs=inputs, outputs=[graph]))

-     assert all([xi is not None for x in input_shapes for xi in at_least_tuple(x)])
+     def f(*args, **kwargs):
+         return f_tuple(*args, **kwargs)[0]

-     def f_jax(x):
-         args = []
-         cursor = 0
-         for shape in input_shapes:
-             n_elements = int(np.prod(shape))
-             s = slice(cursor, cursor + n_elements)
-             args.append(x[s].reshape(shape))
-             cursor += n_elements
-         return f(*args)[0]
-
-     f_logp = jax.jit(f_jax)
+     f_logp = jax.jit(f)

    f_grad = None
    f_hess = None
    f_hessp = None

    if use_grad:
-         _f_grad_jax = jax.grad(f_jax)
+         _f_grad_jax = jax.grad(f)

        def f_grad_jax(x):
            return jax.numpy.stack(_f_grad_jax(x))
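With raveling handled upstream by `join_nonshared_inputs`, the wrapper reduces to plain JAX transforms; the underlying pattern on a toy scalar function (the `jacfwd`-of-`grad` Hessian shown here is one standard construction, not necessarily the PR's exact one):

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 2)          # stand-in for the compiled neg-logp
f_logp = jax.jit(f)
f_grad = jax.jit(jax.grad(f))
f_hess = jax.jit(jax.jacfwd(jax.grad(f)))

x = jnp.ones(3)
f_logp(x), f_grad(x), f_hess(x)        # scalar, (3,), (3, 3)
```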
@@ -411,14 +431,12 @@ def find_MAP(
        {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
    )

-     inputs = [frozen_model.values_to_rvs[vars_dict[x]] for x in start_dict.keys()]
-     inputs = [frozen_model.rvs_to_values[x] for x in inputs]
-
-     logp_factors = frozen_model.logp(sum=False, jacobian=False)
-     neg_logp = -pt.sum([pt.sum(factor) for factor in logp_factors])
+     [neg_logp], inputs = join_nonshared_inputs(
+         point=start_dict, outputs=[-frozen_model.logp()], inputs=frozen_model.continuous_value_vars
+     )

    f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph(
-         neg_logp, use_grad, use_hess, use_hessp, inputs=inputs
+         neg_logp, use_grad, use_hess, use_hessp, inputs=[inputs]
    )

    args = optimizer_kwargs.pop("args", None)
@@ -435,11 +453,12 @@ def find_MAP(
        **optimizer_kwargs,
    )

-     initial_point = RaveledVars(optimizer_result.x, initial_params.point_map_info)
+     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
    unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
    unobserved_vars_values = model.compile_fn(unobserved_vars)(
-         DictToArrayBijection.rmap(initial_point, start_dict)
+         DictToArrayBijection.rmap(raveled_optimized)
    )
+
    optimized_point = {
        var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
    }