
Commit f705d43

in-progress refactor
1 parent 923eb26 commit f705d43

File tree: 1 file changed

pymc_experimental/inference/jax_find_map.py

Lines changed: 123 additions & 104 deletions
@@ -23,6 +23,7 @@
 from pymc.initial_point import make_initial_point_fn
 from pymc.model.transform.conditioning import remove_value_transforms
 from pymc.model.transform.optimization import freeze_dims_and_data
+from pymc.pytensorf import join_nonshared_inputs
 from pymc.sampling.jax import get_jaxified_graph
 from pymc.util import get_default_varnames
 from pytensor.tensor import TensorVariable
@@ -32,13 +33,12 @@
 _log = logging.getLogger(__name__)
 
 
-def get_near_psd(A: np.ndarray) -> np.ndarray:
+def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     """
     Compute the nearest positive semi-definite matrix to a given matrix.
 
-    This function takes a square matrix and returns the nearest positive
-    semi-definite matrix using eigenvalue decomposition. It ensures all
-    eigenvalues are non-negative. The "nearest" matrix is defined in terms
+    This function takes a square matrix and returns the nearest positive semi-definite matrix using
+    eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms
     of the Frobenius norm.
 
     Parameters
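For context, a minimal standalone sketch of the eigenvalue-clipping construction this docstring describes (not the module's exact body, though the hunk below shows it uses the same final reconstruction step):

import numpy as np

def nearest_psd_sketch(A: np.ndarray) -> np.ndarray:
    # Symmetrize first; the Frobenius-nearest PSD matrix is defined for the
    # symmetric part (Higham, 1988).
    B = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(B)  # real eigendecomposition of a symmetric matrix
    eigval = np.clip(eigval, 0, None)   # clip negative eigenvalues to zero
    return eigvec @ np.diag(eigval) @ eigvec.T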
@@ -58,23 +58,13 @@ def get_near_psd(A: np.ndarray) -> np.ndarray:
     return eigvec @ np.diag(eigval) @ eigvec.T
 
 
-def _get_unravel_rv_info(optimized_point, variables, model):
-    cursor = 0
-    slices = {}
-    out_shapes = {}
-
-    for i, var in enumerate(variables):
-        raveled_shape = np.prod(optimized_point[var.name].shape).astype(int)
-        rv = model.values_to_rvs.get(var, var)
-
-        idx = slice(cursor, cursor + raveled_shape)
-        slices[rv] = idx
-        out_shapes[rv] = tuple(
-            [len(model.coords[dim]) for dim in model.named_vars_to_dims.get(rv.name, [])]
-        )
-        cursor += raveled_shape
+def _unconstrained_vector_to_constrained_rvs(model):
+    constrained_rvs, unconstrained_vector = join_nonshared_inputs(
+        model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars
+    )
 
-    return slices, out_shapes
+    unconstrained_vector.name = "unconstrained_vector"
+    return constrained_rvs, unconstrained_vector
 
 
 def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws):
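The new helper leans on pymc.pytensorf.join_nonshared_inputs, which rewrites a list of graph outputs so they depend on a single raveled input vector instead of one value variable per parameter. A toy sketch of the call pattern (the model and variable names here are illustrative, not from the commit):

import pymc as pm
from pymc.pytensorf import join_nonshared_inputs

with pm.Model() as toy_model:
    x = pm.Normal("x", shape=2)
    y = pm.HalfNormal("y")

# Every unobserved quantity becomes a function of one flat 1-D vector, which is
# what later lets the Laplace draws be constrained in a single batched pass.
outputs, joined_vector = join_nonshared_inputs(
    point=toy_model.initial_point(),
    inputs=toy_model.value_vars,
    outputs=toy_model.unobserved_value_vars,
)
joined_vector.name = "unconstrained_vector"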
@@ -94,37 +84,24 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)
 
 
-def fit_laplace(
+def jax_fit_mvn_to_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model,
-    chains: int = 2,
-    draws: int = 500,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
     transform_samples: bool = True,
     zero_tol: float = 1e-8,
     diag_jitter: float | None = 1e-8,
-    progressbar: bool = True,
-    mode: str = "JAX",
-) -> az.InferenceData:
+) -> tuple[RaveledVars, np.ndarray]:
     """
-    Compute the Laplace approximation of the posterior distribution.
-
-    The posterior distribution will be approximated as a Gaussian
-    distribution centered at the posterior mode.
-    The covariance is the inverse of the negative Hessian matrix of
-    the log-posterior evaluated at the mode.
+    Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior
+    evaluated at the MAP estimate. This is the basis of the Laplace approximation.
 
     Parameters
     ----------
     optimized_point : dict[str, np.ndarray]
-        Local maximum a posteriori (MAP) point returned from pymc.find_MAP
-        or jax_tools.fit_map
+        Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
     model : Model
         A PyMC model
-    chains : int
-        The number of sampling chains running in parallel. Default is 2.
-    draws : int
-        The number of samples to draw from the approximated posterior. Default is 500.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
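For reference, the object the renamed function constructs is the standard Laplace approximation: the posterior is replaced by a Gaussian centered at the MAP estimate, with covariance equal to the inverse of the negative Hessian of the log-posterior there,

p(\theta \mid y) \;\approx\; \mathcal{N}\!\left(\hat{\theta}_{\mathrm{MAP}},\; H^{-1}\right),
\qquad
H = -\nabla_{\theta}^{2} \log p(\theta \mid y)\,\Big|_{\theta = \hat{\theta}_{\mathrm{MAP}}}.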
@@ -137,18 +114,17 @@ def fit_laplace(
     diag_jitter: float | None
         A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
         If None, no jitter is added. Default is 1e-8.
-    progressbar : bool
-        Whether or not to display progress bar. Default is True.
-    mode : str
-        Computation backend mode. Default is "JAX".
 
     Returns
     -------
-    InferenceData
-        arviz.InferenceData object storing posterior, observed_data, and constant_data groups.
+    map_estimate: RaveledVars
+        The MAP estimate of the model parameters, raveled into a 1D array.
 
+    inverse_hessian: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     """
     frozen_model = freeze_dims_and_data(model)
+
     if not transform_samples:
         untransformed_model = remove_value_transforms(frozen_model)
         logp = untransformed_model.logp(jacobian=False)
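These two code paths (the else branch appears in the next hunk) differ in which density gets differentiated. A hedged sketch of the distinction, using a toy model; the helper names come from this file's own imports:

import pymc as pm
from pymc.model.transform.conditioning import remove_value_transforms
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model() as m:
    sigma = pm.HalfNormal("sigma")  # stored internally as log(sigma)

frozen = freeze_dims_and_data(m)

# transform_samples=True: curvature is measured on the unconstrained (log)
# scale, including the Jacobian term of the transform.
logp_unconstrained = frozen.logp(jacobian=True)

# transform_samples=False: curvature is measured directly on the constrained scale.
logp_constrained = remove_value_transforms(frozen).logp(jacobian=False)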
@@ -157,19 +133,17 @@
         logp = frozen_model.logp(jacobian=True)
         variables = frozen_model.continuous_value_vars
 
-    mu = np.concatenate(
-        [np.atleast_1d(optimized_point[var.name]).ravel() for var in variables], axis=0
+    mu = DictToArrayBijection.map(optimized_point)
+
+    [neg_logp], flat_inputs = join_nonshared_inputs(
+        point=frozen_model.initial_point(), outputs=[-logp], inputs=variables
     )
 
     f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph(
-        cast(TensorVariable, logp),
-        use_grad=True,
-        use_hess=True,
-        use_hessp=False,
-        inputs=variables,
+        neg_logp, use_grad=True, use_hess=True, use_hessp=False, inputs=[flat_inputs]
     )
 
-    H = f_hess(mu)
+    H = -f_hess(mu.data)
     H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))
 
     def stabilize(x, jitter):
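A self-contained illustration of the Hessian-and-pseudoinverse step above, with a quadratic stand-in for the negative log-posterior and jax.hessian playing the role of f_hess:

import jax
import jax.numpy as jnp
import numpy as np

def neg_logp(x):
    return 0.5 * jnp.sum(x ** 2)  # stand-in for the raveled negative log-posterior

mu_data = np.zeros(3)
H = -np.asarray(jax.hessian(neg_logp)(mu_data))  # Hessian of logp, as in H = -f_hess(mu.data)
zero_tol = 1e-8
# Zero out numerical noise below the tolerance, then invert the *negative*
# Hessian (positive definite near a maximum) to get the covariance.
H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))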
@@ -184,73 +158,111 @@ def stabilize(x, jitter):
             raise np.linalg.LinAlgError(
                 "Inverse Hessian not positive semi-definite at the provided point"
             )
-        H_inv = get_near_psd(H_inv)
+        H_inv = get_nearest_psd(H_inv)
         if on_bad_cov == "warn":
             _log.warning(
                 "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
                 "matrix in L1-norm instead"
             )
 
-    posterior_dist = stats.multivariate_normal(mean=mu, cov=H_inv, allow_singular=True)
+    return mu, H_inv
+
+
+def jax_laplace(
+    mu: RaveledVars,
+    H_inv: np.ndarray,
+    model: pm.Model,
+    chains: int = 2,
+    draws: int = 500,
+    transform_samples: bool = True,
+    progressbar: bool = True,
+) -> az.InferenceData:
+    """
+    Draw samples from the multivariate normal (Laplace) approximation to the posterior.
+
+    Parameters
+    ----------
+    mu: RaveledVars
+        The MAP estimate of the model parameters, raveled into a 1D array.
+    H_inv: np.ndarray
+        The inverse Hessian of the log-posterior, used as the covariance of the approximation.
+    model : Model
+        A PyMC model
+    chains : int
+        The number of sampling chains running in parallel. Default is 2.
+    draws : int
+        The number of samples to draw from the approximated posterior. Default is 500.
+    transform_samples : bool
+        Whether to transform the samples back to the original parameter space. Default is True.
+
+    Returns
+    -------
+    idata: az.InferenceData
+        An InferenceData object containing the approximated posterior samples.
+    """
+    posterior_dist = stats.multivariate_normal(mean=mu.data, cov=H_inv, allow_singular=True)
     posterior_draws = posterior_dist.rvs(size=(chains, draws))
-    slices, out_shapes = _get_unravel_rv_info(optimized_point, variables, frozen_model)
 
     if transform_samples:
-        posterior_draws = _create_transformed_draws(
-            H_inv, slices, out_shapes, posterior_draws, frozen_model, chains, draws
-        )
+        constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
+        f_constrain = get_jaxified_graph(inputs=[unconstrained_vector], outputs=constrained_rvs)
+
+        posterior_draws = jax.jit(jax.vmap(jax.vmap(f_constrain)))(posterior_draws)
+
     else:
+        info = mu.point_map_info
+        flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
+        slices = [
+            slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
+        ]
+
         posterior_draws = [
-            posterior_draws[..., idx].reshape((chains, draws, *out_shapes.get(rv, ())))
-            for rv, idx in slices.items()
+            posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
+            for idx, (name, shape, dtype) in zip(slices, info)
         ]
 
-    def make_rv_coords(rv):
+    def make_rv_coords(name):
         coords = {"chain": range(chains), "draw": range(draws)}
-        extra_dims = frozen_model.named_vars_to_dims.get(rv.name)
+        extra_dims = model.named_vars_to_dims.get(name)
         if extra_dims is None:
             return coords
-        return coords | {dim: list(frozen_model.coords[dim]) for dim in extra_dims}
+        return coords | {dim: list(model.coords[dim]) for dim in extra_dims}
 
-    def make_rv_dims(rv):
+    def make_rv_dims(name):
         dims = ["chain", "draw"]
-        extra_dims = frozen_model.named_vars_to_dims.get(rv.name)
+        extra_dims = model.named_vars_to_dims.get(name)
         if extra_dims is None:
             return dims
         return dims + list(extra_dims)
 
     idata = {
-        rv.name: xr.DataArray(
+        name: xr.DataArray(
             data=draws.squeeze(),
-            coords=make_rv_coords(rv),
-            dims=make_rv_dims(rv),
-            name=rv.name,
+            coords=make_rv_coords(name),
+            dims=make_rv_dims(name),
+            name=name,
         )
-        for rv, draws in zip(slices.keys(), posterior_draws)
+        for (name, _, _), draws in zip(mu.point_map_info, posterior_draws)
     }
 
-    coords, dims = coords_and_dims_for_inferencedata(frozen_model)
+    coords, dims = coords_and_dims_for_inferencedata(model)
     idata = az.convert_to_inference_data(idata, coords=coords, dims=dims)
 
-    if frozen_model.deterministics:
+    if model.deterministics:
         idata.posterior = pm.compute_deterministics(
             idata.posterior,
-            model=frozen_model,
+            model=model,
             merge_dataset=True,
             progressbar=progressbar,
-            compile_kwargs={"mode": mode},
         )
 
     observed_data = dict_to_dataset(
-        find_observations(frozen_model),
+        find_observations(model),
         library=pm,
         coords=coords,
         dims=dims,
         default_dims=[],
     )
 
     constant_data = dict_to_dataset(
-        find_constants(frozen_model),
+        find_constants(model),
         library=pm,
         coords=coords,
         dims=dims,
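The jax.jit(jax.vmap(jax.vmap(...))) pattern above maps the constraining function over both batch dimensions at once. A toy version, with exp standing in for the jaxified constraining graph:

import jax
import jax.numpy as jnp

def constrain(z):
    return jnp.exp(z)  # stand-in for f_constrain

posterior_draws = jnp.zeros((2, 500, 3))          # (chains, draws, flat parameters)
batched = jax.jit(jax.vmap(jax.vmap(constrain)))  # inner vmap over draws, outer over chains
constrained = batched(posterior_draws)            # keeps the leading (2, 500) batch shape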
@@ -266,6 +278,29 @@ def make_rv_dims(rv):
     return idata
 
 
+def fit_laplace(
+    optimized_point: dict[str, np.ndarray],
+    model: pm.Model,
+    chains: int = 2,
+    draws: int = 500,
+    on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
+    transform_samples: bool = True,
+    zero_tol: float = 1e-8,
+    diag_jitter: float | None = 1e-8,
+    progressbar: bool = True,
+) -> az.InferenceData:
+    mu, H_inv = jax_fit_mvn_to_MAP(
+        optimized_point,
+        model,
+        on_bad_cov,
+        transform_samples,
+        zero_tol,
+        diag_jitter,
+    )
+
+    return jax_laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar)
+
+
 def make_jax_funcs_from_graph(
     graph: TensorVariable,
     use_grad: bool,
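After this refactor, fit_laplace is a thin wrapper over the two new pieces. A hypothetical end-to-end call, assuming this module's find_MAP accepts a model keyword and returns the dict that jax_fit_mvn_to_MAP expects (the exact signature is not shown in this diff):

import numpy as np
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu")
    pm.Normal("obs", mu=mu, sigma=1.0, observed=np.array([0.1, -0.3, 0.2]))

optimized_point = find_MAP(model=model)                           # MAP point as a dict
map_estimate, H_inv = jax_fit_mvn_to_MAP(optimized_point, model)  # fit the Gaussian
idata = jax_laplace(map_estimate, H_inv, model, chains=2, draws=500)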
@@ -280,34 +315,19 @@ def make_jax_funcs_from_graph(
     if not isinstance(inputs, list):
         inputs = [inputs]
 
-    f = cast(Callable, get_jaxified_graph(inputs=inputs, outputs=[graph]))
-    input_shapes = [x.type.shape for x in inputs]
-
-    def at_least_tuple(x):
-        if isinstance(x, tuple | list):
-            return x
-        return (x,)
+    f_tuple = cast(Callable, get_jaxified_graph(inputs=inputs, outputs=[graph]))
 
-    assert all([xi is not None for x in input_shapes for xi in at_least_tuple(x)])
+    def f(*args, **kwargs):
+        return f_tuple(*args, **kwargs)[0]
 
-    def f_jax(x):
-        args = []
-        cursor = 0
-        for shape in input_shapes:
-            n_elements = int(np.prod(shape))
-            s = slice(cursor, cursor + n_elements)
-            args.append(x[s].reshape(shape))
-            cursor += n_elements
-        return f(*args)[0]
-
-    f_logp = jax.jit(f_jax)
+    f_logp = jax.jit(f)
 
     f_grad = None
     f_hess = None
     f_hessp = None
 
     if use_grad:
-        _f_grad_jax = jax.grad(f_jax)
+        _f_grad_jax = jax.grad(f)
 
         def f_grad_jax(x):
             return jax.numpy.stack(_f_grad_jax(x))
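The new f wrapper exists because get_jaxified_graph returns a function producing a tuple of outputs, while jax.grad (and jax.hessian) need a scalar-returning function. A minimal reproduction of the pattern:

import jax
import jax.numpy as jnp

def f_tuple(x):
    return (jnp.sum(x ** 2),)  # jaxified graphs return a tuple, even for one output

def f(*args, **kwargs):
    return f_tuple(*args, **kwargs)[0]  # unwrap to a scalar

f_grad = jax.jit(jax.grad(f))
f_grad(jnp.arange(3.0))  # Array([0., 2., 4.])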
@@ -411,14 +431,12 @@ def find_MAP(
         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
     )
 
-    inputs = [frozen_model.values_to_rvs[vars_dict[x]] for x in start_dict.keys()]
-    inputs = [frozen_model.rvs_to_values[x] for x in inputs]
-
-    logp_factors = frozen_model.logp(sum=False, jacobian=False)
-    neg_logp = -pt.sum([pt.sum(factor) for factor in logp_factors])
+    [neg_logp], inputs = join_nonshared_inputs(
+        point=start_dict, outputs=[-frozen_model.logp()], inputs=frozen_model.continuous_value_vars
+    )
 
     f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph(
-        neg_logp, use_grad, use_hess, use_hessp, inputs=inputs
+        neg_logp, use_grad, use_hess, use_hessp, inputs=[inputs]
     )
 
     args = optimizer_kwargs.pop("args", None)
@@ -435,11 +453,12 @@
         **optimizer_kwargs,
     )
 
-    initial_point = RaveledVars(optimizer_result.x, initial_params.point_map_info)
+    raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
     unobserved_vars_values = model.compile_fn(unobserved_vars)(
-        DictToArrayBijection.rmap(initial_point, start_dict)
+        DictToArrayBijection.rmap(raveled_optimized)
     )
+
     optimized_point = {
         var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
     }
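The rename to raveled_optimized highlights the RaveledVars round-trip used throughout this refactor: DictToArrayBijection packs a dict of named arrays into one flat vector plus shape/dtype metadata, and rmap unpacks it again. A small sketch (the names and values are illustrative):

import numpy as np
from pymc.blocking import DictToArrayBijection

point = {"mu": np.array(0.5), "sigma_log__": np.array([0.1, 0.2])}
raveled = DictToArrayBijection.map(point)      # RaveledVars: .data (flat vector) + .point_map_info
restored = DictToArrayBijection.rmap(raveled)  # back to {"mu": ..., "sigma_log__": ...}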
