
Commit 1e902dc

Merge remote-tracking branch 'origin/main' into fix-incomplete-progressbar

2 parents: 93257fc + 00a4ca3

14 files changed: +183 -118 lines changed


.github/workflows/pypi.yml

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@ jobs:
       echo "Checking import and version number (on release)"
       venv-bdist/bin/python -c "import pymc_extras as pmx; assert pmx.__version__ == '${{ github.ref_name }}'[1:] if '${{ github.ref_type }}' == 'tag' else pmx.__version__; print(pmx.__version__)"
       cd ..
-    - uses: actions/upload-artifact@v3
+    - uses: actions/upload-artifact@v4
      with:
        name: artifact
        path: dist/*
@@ -58,7 +58,7 @@ jobs:
     # write id-token is necessary for trusted publishing (OIDC)
     id-token: write
     steps:
-    - uses: actions/download-artifact@v3
+    - uses: actions/download-artifact@v4
      with:
        name: artifact
        path: dist

conda-envs/environment-test.yml

Lines changed: 3 additions & 2 deletions
@@ -3,14 +3,15 @@ channels:
   - conda-forge
   - nodefaults
 dependencies:
-  - pymc>=5.19.1
+  - pymc>=5.20
   - pytest-cov>=2.5
   - pytest>=3.0
   - dask
   - xhistogram
   - statsmodels
+  - numba<=0.60.0
   - pip
   - pip:
     - blackjax
     - scikit-learn
-    - better_optimize>=0.0.10
+    - better_optimize

conda-envs/windows-environment-test.yml

Lines changed: 3 additions & 2 deletions
@@ -9,8 +9,9 @@ dependencies:
   - dask
   - xhistogram
   - statsmodels
+  - numba<=0.60.0
+  - pymc>=5.20
   - pip:
-    - pymc>=5.19.1 # CI was failing to resolve
     - blackjax
     - scikit-learn
-    - better_optimize>=0.0.10
+    - better_optimize

notebooks/Exponential Trend Smoothing.ipynb

Lines changed: 52 additions & 52 deletions
Large diffs are not rendered by default.

notebooks/Making a Custom Statespace Model.ipynb

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@
 "\n",
 "The statespace module is designed to make it easy for users to create their own statespace models. At its core, a statspace model is just a system of two linear equatons:\n",
 "\n",
-"$$\\begin{align} x_{t+1} &= A_t x_t + c_t + R_t \\varepsilon_t, & \\varepsilon_t &\\sim N(0, Q_t) \\\\\n",
+"$$\\begin{align} x_{t} &= A_t x_{t-1} + c_t + R_t \\varepsilon_t, & \\varepsilon_t &\\sim N(0, Q_t) \\\\\n",
 "y_t &= Z_t x_t + d_t + \\eta_t, & \\eta_t &\\sim N(0, H_t) \\\\\n",
 "x_0 &\\sim N(\\bar x, P)\\end{align}$$\n",
 "\n",

pymc_extras/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,7 +15,9 @@

 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
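These two re-exports make `find_MAP` and `fit_laplace` importable from the package root. A minimal usage sketch, assuming both follow the usual PyMC convention of picking up the model from the context stack (default arguments here are assumptions, not shown in this diff):

```python
import pymc as pm
import pymc_extras as pmx

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu, 1, observed=[0.1, -0.3, 0.2])

    map_point = pmx.find_MAP()  # newly available at the package root
    idata = pmx.fit_laplace()   # likewise
```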

pymc_extras/inference/find_map.py

Lines changed: 36 additions & 16 deletions
@@ -1,9 +1,9 @@
 import logging

 from collections.abc import Callable
+from importlib.util import find_spec
 from typing import Literal, cast, get_args

-import jax
 import numpy as np
 import pymc as pm
 import pytensor
@@ -30,13 +30,29 @@
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()

-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False

+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp
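The new defaulting rules are easier to see in isolation. A standalone sketch of the same logic, with the warning elided and an illustrative stand-in for one entry of the library's `MINIMIZE_MODE_KWARGS` table:

```python
# Illustrative stand-in for a method that supports grad, hess, and hessp.
METHOD_INFO = {"uses_grad": True, "uses_hess": True, "uses_hessp": True}

def resolve(use_grad, use_hess, use_hessp):
    """Mirror of set_optimizer_function_defaults, minus the warning."""
    if use_hess and use_hessp:
        use_hess = False  # scipy.optimize.minimize never uses both at once

    use_grad = use_grad if use_grad is not None else METHOD_INFO["uses_grad"]

    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp       # user picked one; infer the other
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hessp is None and use_hess is None:
        use_hessp = METHOD_INFO["uses_hessp"]
        use_hess = METHOD_INFO["uses_hess"]
        if use_hessp and use_hess:
            use_hess = False           # prefer hessp when both are available

    return use_grad, use_hess, use_hessp

print(resolve(None, None, None))  # (True, False, True): hessp wins by default
print(resolve(None, True, None))  # (True, True, False): explicit hess respected
```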

@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0

     return eigvec @ np.diag(eigval) @ eigvec.T
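A short illustration of why `eigh` is the right call here: `C` is symmetric by construction, and `eigh` guarantees real eigenvalues and eigenvectors, whereas `eig` can return complex values with spurious imaginary parts from floating-point noise. A self-contained sketch of the projection:

```python
import numpy as np

A = np.array([[1.0, 2.0], [0.0, -1.0]])  # not symmetric
C = (A + A.T) / 2                        # symmetrized, as in get_nearest_psd

eigval, eigvec = np.linalg.eigh(C)       # real-valued for symmetric input
eigval[eigval < 0] = 0                   # clip negative eigenvalues
A_psd = eigvec @ np.diag(eigval) @ eigvec.T

assert np.all(np.linalg.eigvalsh(A_psd) >= -1e-12)  # PSD up to rounding
```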
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)


-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
+    import jax
+
     f_hess = None
     f_hessp = None
@@ -152,7 +170,7 @@ def f_hess_jax(x):
     return f_loss_and_grad, f_hess, f_hessp


-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -193,19 +211,19 @@
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]

     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)

     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

     return [f_loss_and_grad, f_hess, f_hessp]
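The repeated change in this hunk is the rename `pm.compile_pymc` → `pm.compile`. A minimal, self-contained sketch of the loss-and-gradient compilation pattern shown above, using a toy loss with no PyMC model involved:

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
loss = (x**2).sum()  # toy loss for illustration

# Same pattern as the hunk: ravel and concatenate per-input gradients,
# then compile one function returning both loss and gradient.
grads = pytensor.gradient.grad(loss, [x])
grad = pt.concatenate([g.ravel() for g in grads])
f_loss_and_grad = pm.compile([x], [loss, grad])

value, gradient = f_loss_and_grad(np.array([1.0, 2.0]))
print(value, gradient)  # 5.0 [2. 4.]
```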

@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -265,6 +283,8 @@
     )

     use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")

     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:
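Together with the deferred `import jax` inside `_compile_grad_and_hess_to_jax`, this `find_spec` guard is the standard optional-dependency pattern: jax is no longer imported at module load time, and a clear error is raised only when the jax backend is actually requested. A generic sketch of the pattern (the function name here is illustrative, not from the diff):

```python
from importlib.util import find_spec

def use_jax_backend():
    # Cheap availability check without importing the package.
    if not find_spec("jax"):
        raise ImportError("JAX must be installed to use JAX gradients")
    import jax  # deferred: a module-level import would make jax a hard dependency
    return jax.grad(lambda x: x**2)(1.0)  # 2.0
```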
@@ -285,7 +305,7 @@
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients

-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +321,7 @@

     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

     return f_loss, f_hess, f_hessp

pymc_extras/inference/laplace.py

Lines changed: 17 additions & 10 deletions
@@ -16,6 +16,7 @@
 import logging

 from functools import reduce
+from importlib.util import find_spec
 from itertools import product
 from typing import Literal

@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
     return idata


-def fit_mvn_to_MAP(
+def fit_mvn_at_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
     inverse_hessian: np.ndarray
         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     """
+    if gradient_backend == "jax" and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
+
     model = pm.modelcontext(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     frozen_model = freeze_dims_and_data(model)
@@ -344,8 +348,10 @@ def sample_laplace_posterior(

     Parameters
     ----------
-    mu
-    H_inv
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model : Model
         A PyMC model
     chains : int
@@ -384,9 +390,7 @@
         constrained_rvs, replace={unconstrained_vector: batched_values}
     )

-    f_constrain = pm.compile_pymc(
-        inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-    )
+    f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
     posterior_draws = f_constrain(posterior_draws)

     else:
@@ -472,15 +476,17 @@
         and 1).

         .. warning::
-            This argumnet should be considered highly experimental. It has not been verified if this method produces
+            This argument should be considered highly experimental. It has not been verified if this method produces
             valid draws from the posterior. **Use at your own risk**.

     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
-        The number of sampling chains running in parallel.
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
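Since `chains` here is only an output-shape convention, a hedged usage sketch of what the new docstring describes (assuming `fit_laplace` picks the model up from the context stack; other defaults live in the signature, which this diff does not show):

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu, 1, observed=np.random.normal(loc=0.5, size=50))

    idata = pmx.fit_laplace(chains=2, draws=500)

# Not parallel MCMC chains: just an ArviZ-compatible layout.
print(idata.posterior["mu"].shape)  # expected (2, 500); total = chains * draws
```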
@@ -547,11 +553,12 @@
         **optimizer_kwargs,
     )

-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         on_bad_cov=on_bad_cov,
         transform_samples=fit_in_unconstrained_space,
+        gradient_backend=gradient_backend,
         zero_tol=zero_tol,
         diag_jitter=diag_jitter,
         compile_kwargs=compile_kwargs,

pymc_extras/model/marginal/marginal_model.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@
     model_free_rv,
     model_from_fgraph,
 )
-from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
+from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace
+from pymc.pytensorf import compile as compile_pymc
 from pymc.util import RandomState, _get_seeds_per_chain
 from pytensor import In, Out
 from pytensor.compile import SharedVariable
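Worth noting as a migration pattern: binding the new name to the old local one means no call sites in this module need to change.

```python
# Same local name, new API underneath; every existing compile_pymc(...) call
# in the module keeps working unchanged.
from pymc.pytensorf import compile as compile_pymc
```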

pymc_extras/statespace/core/compile.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def compile_statespace(

     inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))

-    _f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
+    _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)

     def f(*, draws=1, **params):
         if isinstance(steps, pt.Variable):
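A small sketch of why `on_unused_input="ignore"` matters here: inputs are collected from the whole graph, so some of them may not feed every output, and pytensor raises an error on unused inputs by default (the toy variables below are illustrative):

```python
import pymc as pm
import pytensor.tensor as pt

a = pt.scalar("a")
b = pt.scalar("b")  # collected as an input but unused in the output

f = pm.compile([a, b], a + 1, on_unused_input="ignore")
print(f(1.0, 99.0))  # 2.0; without the flag this would raise UnusedInputError
```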
