Skip to content

Commit e6767ab

Browse files
juanitorduztwiecki
andauthored
pre-commit update ruff 0.9.1 (#7648)
Co-authored-by: Thomas Wiecki <[email protected]>
1 parent bd519d4 commit e6767ab

File tree

18 files changed

+51
-54
lines changed

18 files changed

+51
-54
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ repos:
4848
# - --exclude=binder/
4949
# - --exclude=versioneer.py
5050
- repo: https://github.com/astral-sh/ruff-pre-commit
51-
rev: v0.8.4
51+
rev: v0.9.1
5252
hooks:
5353
- id: ruff
5454
args: [--fix, --show-fixes]

docs/source/learn/core_notebooks/pymc_pytensor.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1849,7 +1849,7 @@
18491849
"print(\n",
18501850
" f\"\"\"\n",
18511851
"mu_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=2)}\n",
1852-
"sigma_log_value -> {- 10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n",
1852+
"sigma_log_value -> {-10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n",
18531853
"x_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=np.exp(-10))}\n",
18541854
"\"\"\"\n",
18551855
")"

pymc/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def determine_coords(
257257
if isinstance(value, np.ndarray) and dims is not None:
258258
if len(dims) != value.ndim:
259259
raise pm.exceptions.ShapeError(
260-
"Invalid data shape. The rank of the dataset must match the " "length of `dims`.",
260+
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
261261
actual=value.shape,
262262
expected=value.ndim,
263263
)

pymc/distributions/continuous.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -992,8 +992,7 @@ def get_mu_lam_phi(mu, lam, phi):
992992
return mu, lam, lam / mu
993993

994994
raise ValueError(
995-
"Wald distribution must specify either mu only, "
996-
"mu and lam, mu and phi, or lam and phi."
995+
"Wald distribution must specify either mu only, mu and lam, mu and phi, or lam and phi."
997996
)
998997

999998
def logp(value, mu, lam, alpha):
@@ -1603,8 +1602,7 @@ def dist(cls, kappa=None, mu=None, b=None, q=None, *args, **kwargs):
16031602
def get_kappa(cls, kappa=None, q=None):
16041603
if kappa is not None and q is not None:
16051604
raise ValueError(
1606-
"Incompatible parameterization. Either use "
1607-
"kappa or q to specify the distribution."
1605+
"Incompatible parameterization. Either use kappa or q to specify the distribution."
16081606
)
16091607
elif q is not None:
16101608
if isinstance(q, Variable):
@@ -3483,7 +3481,7 @@ def get_nu_b(cls, nu, b, sigma):
34833481
elif nu is not None and b is None:
34843482
b = nu / sigma
34853483
return nu, b, sigma
3486-
raise ValueError("Rice distribution must specify either nu" " or b.")
3484+
raise ValueError("Rice distribution must specify either nu or b.")
34873485

34883486
def support_point(rv, size, nu, sigma):
34893487
nu_sigma_ratio = -(nu**2) / (2 * sigma**2)

pymc/distributions/multivariate.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ class MvNormal(Continuous):
247247
data = np.random.multivariate_normal(mu, true_cov, 10)
248248
249249
sd_dist = pm.Exponential.dist(1.0, shape=3)
250-
chol, corr, stds = pm.LKJCholeskyCov("chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True)
250+
chol, corr, stds = pm.LKJCholeskyCov(
251+
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
252+
)
251253
vals = pm.MvNormal("vals", mu=mu, chol=chol, observed=data)
252254
253255
For unobserved values it can be better to use a non-centered
@@ -2793,9 +2795,9 @@ def dist(cls, sigma=1.0, n_zerosum_axes=None, support_shape=None, **kwargs):
27932795

27942796
support_shape = pt.as_tensor(support_shape, dtype="int64", ndim=1)
27952797

2796-
assert n_zerosum_axes == pt.get_vector_length(
2797-
support_shape
2798-
), "support_shape has to be as long as n_zerosum_axes"
2798+
assert n_zerosum_axes == pt.get_vector_length(support_shape), (
2799+
"support_shape has to be as long as n_zerosum_axes"
2800+
)
27992801

28002802
return super().dist([sigma, support_shape], **kwargs)
28012803

pymc/gp/cov.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,7 @@ def power_spectral_density(self, omega: TensorLike) -> TensorVariable:
328328
check = Counter([isinstance(factor, Covariance) for factor in self._factor_list])
329329
if check.get(True, 0) >= 2:
330330
raise NotImplementedError(
331-
"The power spectral density of products of covariance "
332-
"functions is not implemented."
331+
"The power spectral density of products of covariance functions is not implemented."
333332
)
334333
return reduce(mul, self._merge_factors_psd(omega))
335334

pymc/gp/util.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,7 @@ def plot_gp_dist(
211211
samples_kwargs = {}
212212
if np.any(np.isnan(samples)):
213213
warnings.warn(
214-
"There are `nan` entries in the [samples] arguments. "
215-
"The plot will not contain a band!",
214+
"There are `nan` entries in the [samples] arguments. The plot will not contain a band!",
216215
UserWarning,
217216
)
218217

pymc/sampling/jax.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl
108108

109109
if any(var.default_update is not None for var in shared_variables):
110110
raise ValueError(
111-
"Graph contains shared variables with default_update which cannot "
112-
"be safely replaced."
111+
"Graph contains shared variables with default_update which cannot be safely replaced."
113112
)
114113

115114
replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
@@ -360,7 +359,7 @@ def _sample_blackjax_nuts(
360359
map_fn = jax.vmap
361360
else:
362361
raise ValueError(
363-
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
362+
"Only supporting the following methods to draw chains: 'parallel' or 'vectorized'"
364363
)
365364

366365
if chains == 1:

pymc/sampling/mcmc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,7 @@ def _sample_return(
10001000
total_draws = draws_per_chain.sum()
10011001

10021002
_log.info(
1003-
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations '
1003+
f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations "
10041004
f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) "
10051005
f"took {t_sampling:.0f} seconds."
10061006
)
@@ -1062,8 +1062,8 @@ def _sample_return(
10621062

10631063
n_chains = len(mtrace.chains)
10641064
_log.info(
1065-
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
1066-
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
1065+
f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {n_tune:_d} tune and {n_draws:_d} draw iterations "
1066+
f"({n_tune * n_chains:_d} + {n_draws * n_chains:_d} draws total) "
10671067
f"took {t_sampling:.0f} seconds."
10681068
)
10691069

pymc/sampling/population.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,9 @@ def _prepare_iter_population(
386386

387387
# 2. Set up the steppers
388388
steppers: list[Step] = []
389-
assert (
390-
len(rngs) == nchains
391-
), f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}"
389+
assert len(rngs) == nchains, (
390+
f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}"
391+
)
392392
for c, rng in enumerate(rngs):
393393
# need independent samplers for each chain
394394
# it is important to copy the actual steppers (but not the delta_logp)

0 commit comments

Comments
 (0)