Skip to content

Commit 5f3f5ec

Browse files
committed
Merge branch 'main' into jitter_scale
2 parents c747e63 + e6767ab commit 5f3f5ec

File tree

19 files changed

+79
-69
lines changed

19 files changed

+79
-69
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ repos:
4848
# - --exclude=binder/
4949
# - --exclude=versioneer.py
5050
- repo: https://github.com/astral-sh/ruff-pre-commit
51-
rev: v0.8.4
51+
rev: v0.9.1
5252
hooks:
5353
- id: ruff
5454
args: [--fix, --show-fixes]

docs/source/learn/core_notebooks/pymc_pytensor.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1849,7 +1849,7 @@
18491849
"print(\n",
18501850
" f\"\"\"\n",
18511851
"mu_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=2)}\n",
1852-
"sigma_log_value -> {- 10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n",
1852+
"sigma_log_value -> {-10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n",
18531853
"x_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=np.exp(-10))}\n",
18541854
"\"\"\"\n",
18551855
")"

pymc/backends/zarr.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@
1515
from typing import Any
1616

1717
import arviz as az
18-
import numcodecs
1918
import numpy as np
2019
import xarray as xr
21-
import zarr
2220

2321
from arviz.data.base import make_attrs
2422
from arviz.data.inference_data import WARMUP_TAG
25-
from numcodecs.abc import Codec
2623
from pytensor.tensor.variable import TensorVariable
2724

2825
import pymc
@@ -44,11 +41,23 @@
4441
from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name
4542

4643
try:
44+
import numcodecs
45+
import zarr
46+
47+
from numcodecs.abc import Codec
48+
from zarr import Group
4749
from zarr.storage import BaseStore, default_compressor
4850
from zarr.sync import Synchronizer
4951

5052
_zarr_available = True
5153
except ImportError:
54+
from typing import TYPE_CHECKING, TypeVar
55+
56+
if not TYPE_CHECKING:
57+
Codec = TypeVar("Codec")
58+
Group = TypeVar("Group")
59+
BaseStore = TypeVar("BaseStore")
60+
Synchronizer = TypeVar("Synchronizer")
5261
_zarr_available = False
5362

5463

@@ -243,7 +252,7 @@ def flush(self):
243252

244253
def get_initial_fill_value_and_codec(
245254
dtype: Any,
246-
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
255+
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]:
247256
_dtype = np.dtype(dtype)
248257
fill_value: FILL_VALUE_TYPE = None
249258
codec = None
@@ -366,27 +375,27 @@ def groups(self) -> list[str]:
366375
return [str(group_name) for group_name, _ in self.root.groups()]
367376

368377
@property
369-
def posterior(self) -> zarr.Group:
378+
def posterior(self) -> Group:
370379
return self.root.posterior
371380

372381
@property
373-
def unconstrained_posterior(self) -> zarr.Group:
382+
def unconstrained_posterior(self) -> Group:
374383
return self.root.unconstrained_posterior
375384

376385
@property
377-
def sample_stats(self) -> zarr.Group:
386+
def sample_stats(self) -> Group:
378387
return self.root.sample_stats
379388

380389
@property
381-
def constant_data(self) -> zarr.Group:
390+
def constant_data(self) -> Group:
382391
return self.root.constant_data
383392

384393
@property
385-
def observed_data(self) -> zarr.Group:
394+
def observed_data(self) -> Group:
386395
return self.root.observed_data
387396

388397
@property
389-
def _sampling_state(self) -> zarr.Group:
398+
def _sampling_state(self) -> Group:
390399
return self.root._sampling_state
391400

392401
def init_trace(
@@ -646,12 +655,12 @@ def init_sampling_state_group(self, tune: int, chains: int):
646655

647656
def init_group_with_empty(
648657
self,
649-
group: zarr.Group,
658+
group: Group,
650659
var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]],
651660
chains: int,
652661
draws: int,
653662
extra_var_attrs: dict | None = None,
654-
) -> zarr.Group:
663+
) -> Group:
655664
group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)}
656665
for name, (_dtype, shape) in var_dtype_and_shape.items():
657666
fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype)
@@ -689,8 +698,8 @@ def init_group_with_empty(
689698
array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
690699
return group
691700

692-
def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None:
693-
group: zarr.Group | None = None
701+
def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None:
702+
group: Group | None = None
694703
if data_dict:
695704
group_coords = {}
696705
group = self.root.create_group(name=name, overwrite=True)

pymc/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def determine_coords(
257257
if isinstance(value, np.ndarray) and dims is not None:
258258
if len(dims) != value.ndim:
259259
raise pm.exceptions.ShapeError(
260-
"Invalid data shape. The rank of the dataset must match the " "length of `dims`.",
260+
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
261261
actual=value.shape,
262262
expected=value.ndim,
263263
)

pymc/distributions/continuous.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -992,8 +992,7 @@ def get_mu_lam_phi(mu, lam, phi):
992992
return mu, lam, lam / mu
993993

994994
raise ValueError(
995-
"Wald distribution must specify either mu only, "
996-
"mu and lam, mu and phi, or lam and phi."
995+
"Wald distribution must specify either mu only, mu and lam, mu and phi, or lam and phi."
997996
)
998997

999998
def logp(value, mu, lam, alpha):
@@ -1603,8 +1602,7 @@ def dist(cls, kappa=None, mu=None, b=None, q=None, *args, **kwargs):
16031602
def get_kappa(cls, kappa=None, q=None):
16041603
if kappa is not None and q is not None:
16051604
raise ValueError(
1606-
"Incompatible parameterization. Either use "
1607-
"kappa or q to specify the distribution."
1605+
"Incompatible parameterization. Either use kappa or q to specify the distribution."
16081606
)
16091607
elif q is not None:
16101608
if isinstance(q, Variable):
@@ -3483,7 +3481,7 @@ def get_nu_b(cls, nu, b, sigma):
34833481
elif nu is not None and b is None:
34843482
b = nu / sigma
34853483
return nu, b, sigma
3486-
raise ValueError("Rice distribution must specify either nu" " or b.")
3484+
raise ValueError("Rice distribution must specify either nu or b.")
34873485

34883486
def support_point(rv, size, nu, sigma):
34893487
nu_sigma_ratio = -(nu**2) / (2 * sigma**2)

pymc/distributions/multivariate.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ class MvNormal(Continuous):
247247
data = np.random.multivariate_normal(mu, true_cov, 10)
248248
249249
sd_dist = pm.Exponential.dist(1.0, shape=3)
250-
chol, corr, stds = pm.LKJCholeskyCov("chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True)
250+
chol, corr, stds = pm.LKJCholeskyCov(
251+
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
252+
)
251253
vals = pm.MvNormal("vals", mu=mu, chol=chol, observed=data)
252254
253255
For unobserved values it can be better to use a non-centered
@@ -2793,9 +2795,9 @@ def dist(cls, sigma=1.0, n_zerosum_axes=None, support_shape=None, **kwargs):
27932795

27942796
support_shape = pt.as_tensor(support_shape, dtype="int64", ndim=1)
27952797

2796-
assert n_zerosum_axes == pt.get_vector_length(
2797-
support_shape
2798-
), "support_shape has to be as long as n_zerosum_axes"
2798+
assert n_zerosum_axes == pt.get_vector_length(support_shape), (
2799+
"support_shape has to be as long as n_zerosum_axes"
2800+
)
27992801

28002802
return super().dist([sigma, support_shape], **kwargs)
28012803

pymc/gp/cov.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,7 @@ def power_spectral_density(self, omega: TensorLike) -> TensorVariable:
328328
check = Counter([isinstance(factor, Covariance) for factor in self._factor_list])
329329
if check.get(True, 0) >= 2:
330330
raise NotImplementedError(
331-
"The power spectral density of products of covariance "
332-
"functions is not implemented."
331+
"The power spectral density of products of covariance functions is not implemented."
333332
)
334333
return reduce(mul, self._merge_factors_psd(omega))
335334

pymc/gp/util.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,7 @@ def plot_gp_dist(
211211
samples_kwargs = {}
212212
if np.any(np.isnan(samples)):
213213
warnings.warn(
214-
"There are `nan` entries in the [samples] arguments. "
215-
"The plot will not contain a band!",
214+
"There are `nan` entries in the [samples] arguments. The plot will not contain a band!",
216215
UserWarning,
217216
)
218217

pymc/sampling/jax.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl
108108

109109
if any(var.default_update is not None for var in shared_variables):
110110
raise ValueError(
111-
"Graph contains shared variables with default_update which cannot "
112-
"be safely replaced."
111+
"Graph contains shared variables with default_update which cannot be safely replaced."
113112
)
114113

115114
replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
@@ -360,7 +359,7 @@ def _sample_blackjax_nuts(
360359
map_fn = jax.vmap
361360
else:
362361
raise ValueError(
363-
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
362+
"Only supporting the following methods to draw chains: 'parallel' or 'vectorized'"
364363
)
365364

366365
if chains == 1:

pymc/sampling/mcmc.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from rich.theme import Theme
4242
from threadpoolctl import threadpool_limits
4343
from typing_extensions import Protocol
44-
from zarr.storage import MemoryStore
4544

4645
import pymc as pm
4746

@@ -80,6 +79,11 @@
8079
)
8180
from pymc.vartypes import discrete_types
8281

82+
try:
83+
from zarr.storage import MemoryStore
84+
except ImportError:
85+
MemoryStore = type("MemoryStore", (), {})
86+
8387
sys.setrecursionlimit(10000)
8488

8589
__all__ = [
@@ -996,7 +1000,7 @@ def _sample_return(
9961000
total_draws = draws_per_chain.sum()
9971001

9981002
_log.info(
999-
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations '
1003+
f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations "
10001004
f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) "
10011005
f"took {t_sampling:.0f} seconds."
10021006
)
@@ -1058,8 +1062,8 @@ def _sample_return(
10581062

10591063
n_chains = len(mtrace.chains)
10601064
_log.info(
1061-
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
1062-
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
1065+
f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {n_tune:_d} tune and {n_draws:_d} draw iterations "
1066+
f"({n_tune * n_chains:_d} + {n_draws * n_chains:_d} draws total) "
10631067
f"took {t_sampling:.0f} seconds."
10641068
)
10651069

0 commit comments

Comments
 (0)