From 1033d4732715b00f06a781eaa49950ee7265c39c Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Thu, 16 Jan 2025 09:50:56 +0100 Subject: [PATCH 1/3] pre-commit update --- .pre-commit-config.yaml | 2 +- docs/source/learn/core_notebooks/pymc_pytensor.ipynb | 2 +- pymc/data.py | 2 +- pymc/distributions/continuous.py | 8 +++----- pymc/distributions/multivariate.py | 10 ++++++---- pymc/gp/cov.py | 3 +-- pymc/gp/util.py | 3 +-- pymc/sampling/jax.py | 5 ++--- pymc/sampling/mcmc.py | 6 +++--- pymc/sampling/population.py | 6 +++--- pymc/step_methods/compound.py | 6 +++--- pymc/step_methods/state.py | 6 +++--- pymc/testing.py | 6 +++--- pymc/variational/opvi.py | 8 ++++---- pymc/variational/updates.py | 2 +- tests/distributions/test_multivariate.py | 12 ++++++------ tests/gp/test_hsgp_approx.py | 12 ++++++------ tests/test_data.py | 6 +++--- 18 files changed, 51 insertions(+), 54 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 10fd36fd94..2ba656a365 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: # - --exclude=binder/ # - --exclude=versioneer.py - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.4 + rev: v0.9.1 hooks: - id: ruff args: [--fix, --show-fixes] diff --git a/docs/source/learn/core_notebooks/pymc_pytensor.ipynb b/docs/source/learn/core_notebooks/pymc_pytensor.ipynb index aad72316a3..0260f960d1 100644 --- a/docs/source/learn/core_notebooks/pymc_pytensor.ipynb +++ b/docs/source/learn/core_notebooks/pymc_pytensor.ipynb @@ -1849,7 +1849,7 @@ "print(\n", " f\"\"\"\n", "mu_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=2)}\n", - "sigma_log_value -> {- 10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n", + "sigma_log_value -> {-10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n", "x_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=np.exp(-10))}\n", "\"\"\"\n", ")" diff --git a/pymc/data.py b/pymc/data.py index 997f0ccb3c..fd2ef8e82c 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -257,7 +257,7 @@ def determine_coords( if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: raise pm.exceptions.ShapeError( - "Invalid data shape. The rank of the dataset must match the " "length of `dims`.", + "Invalid data shape. The rank of the dataset must match the length of `dims`.", actual=value.shape, expected=value.ndim, ) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 3746f90fac..21a683ca99 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -992,8 +992,7 @@ def get_mu_lam_phi(mu, lam, phi): return mu, lam, lam / mu raise ValueError( - "Wald distribution must specify either mu only, " - "mu and lam, mu and phi, or lam and phi." + "Wald distribution must specify either mu only, mu and lam, mu and phi, or lam and phi." ) def logp(value, mu, lam, alpha): @@ -1603,8 +1602,7 @@ def dist(cls, kappa=None, mu=None, b=None, q=None, *args, **kwargs): def get_kappa(cls, kappa=None, q=None): if kappa is not None and q is not None: raise ValueError( - "Incompatible parameterization. Either use " - "kappa or q to specify the distribution." + "Incompatible parameterization. Either use kappa or q to specify the distribution." ) elif q is not None: if isinstance(q, Variable): @@ -3483,7 +3481,7 @@ def get_nu_b(cls, nu, b, sigma): elif nu is not None and b is None: b = nu / sigma return nu, b, sigma - raise ValueError("Rice distribution must specify either nu" " or b.") + raise ValueError("Rice distribution must specify either nu or b.") def support_point(rv, size, nu, sigma): nu_sigma_ratio = -(nu**2) / (2 * sigma**2) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index e44008fe65..949c592aba 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -247,7 +247,9 @@ class MvNormal(Continuous): data = np.random.multivariate_normal(mu, true_cov, 10) sd_dist = pm.Exponential.dist(1.0, shape=3) - chol, corr, stds = pm.LKJCholeskyCov("chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True) + chol, corr, stds = pm.LKJCholeskyCov( + "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True + ) vals = pm.MvNormal("vals", mu=mu, chol=chol, observed=data) For unobserved values it can be better to use a non-centered @@ -2793,9 +2795,9 @@ def dist(cls, sigma=1.0, n_zerosum_axes=None, support_shape=None, **kwargs): support_shape = pt.as_tensor(support_shape, dtype="int64", ndim=1) - assert n_zerosum_axes == pt.get_vector_length( - support_shape - ), "support_shape has to be as long as n_zerosum_axes" + assert n_zerosum_axes == pt.get_vector_length(support_shape), ( + "support_shape has to be as long as n_zerosum_axes" + ) return super().dist([sigma, support_shape], **kwargs) diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py index d9f3577280..bc056be515 100644 --- a/pymc/gp/cov.py +++ b/pymc/gp/cov.py @@ -328,8 +328,7 @@ def power_spectral_density(self, omega: TensorLike) -> TensorVariable: check = Counter([isinstance(factor, Covariance) for factor in self._factor_list]) if check.get(True, 0) >= 2: raise NotImplementedError( - "The power spectral density of products of covariance " - "functions is not implemented." + "The power spectral density of products of covariance functions is not implemented." ) return reduce(mul, self._merge_factors_psd(omega)) diff --git a/pymc/gp/util.py b/pymc/gp/util.py index 3aaf85ab54..b2d0486a2b 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -211,8 +211,7 @@ def plot_gp_dist( samples_kwargs = {} if np.any(np.isnan(samples)): warnings.warn( - "There are `nan` entries in the [samples] arguments. " - "The plot will not contain a band!", + "There are `nan` entries in the [samples] arguments. The plot will not contain a band!", UserWarning, ) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 43e1baa87f..2823d0cfff 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -108,8 +108,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl if any(var.default_update is not None for var in shared_variables): raise ValueError( - "Graph contains shared variables with default_update which cannot " - "be safely replaced." + "Graph contains shared variables with default_update which cannot be safely replaced." ) replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables} @@ -360,7 +359,7 @@ def _sample_blackjax_nuts( map_fn = jax.vmap else: raise ValueError( - "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"' + 'Only supporting the following methods to draw chains: "parallel" or "vectorized"' ) if chains == 1: diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 7cbb6df26e..64d6829fc8 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1000,7 +1000,7 @@ def _sample_return( total_draws = draws_per_chain.sum() _log.info( - f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations ' + f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations " f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) " f"took {t_sampling:.0f} seconds." ) @@ -1062,8 +1062,8 @@ def _sample_return( n_chains = len(mtrace.chains) _log.info( - f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations ' - f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) " + f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {n_tune:_d} tune and {n_draws:_d} draw iterations " + f"({n_tune * n_chains:_d} + {n_draws * n_chains:_d} draws total) " f"took {t_sampling:.0f} seconds." ) diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index b8a7ba593a..ab024f1e4f 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -386,9 +386,9 @@ def _prepare_iter_population( # 2. Set up the steppers steppers: list[Step] = [] - assert ( - len(rngs) == nchains - ), f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}" + assert len(rngs) == nchains, ( + f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}" + ) for c, rng in enumerate(rngs): # need independent samplers for each chain # it is important to copy the actual steppers (but not the delta_logp) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 1fcb3d2673..b823a00be8 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -282,9 +282,9 @@ def sampling_state(self) -> DataClassState: @sampling_state.setter def sampling_state(self, state: DataClassState): - assert isinstance( - state, self._state_class - ), f"Invalid sampling state class {type(state)}. Expected {self._state_class}" + assert isinstance(state, self._state_class), ( + f"Invalid sampling state class {type(state)}. Expected {self._state_class}" + ) for method, state_method in zip(self.methods, state.methods): method.sampling_state = state_method diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index ec7bbbae48..db62ffda91 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -90,9 +90,9 @@ def sampling_state(self) -> DataClassState: @sampling_state.setter def sampling_state(self, state: DataClassState): state_class = self._state_class - assert isinstance( - state, state_class - ), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'" + assert isinstance(state, state_class), ( + f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'" + ) for field in fields(state_class): is_tensor_name = field.metadata.get("tensor_name", False) state_val = deepcopy(getattr(state, field.name)) diff --git a/pymc/testing.py b/pymc/testing.py index cc7433980c..5e0fa1ab0c 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -964,9 +964,9 @@ def check_rv_size(self): assert actual == expected_symbolic == expected def validate_tests_list(self): - assert len(self.checks_to_run) == len( - set(self.checks_to_run) - ), "There are duplicates in the list of checks_to_run" + assert len(self.checks_to_run) == len(set(self.checks_to_run)), ( + "There are duplicates in the list of checks_to_run" + ) def seeded_scipy_distribution_builder(dist_name: str) -> Callable: diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 034e2fed87..a054f51e62 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -710,9 +710,9 @@ class Group(WithMemoization): @classmethod def register(cls, sbcls): - assert ( - frozenset(sbcls.__param_spec__) not in cls.__param_registry - ), "Duplicate __param_spec__" + assert frozenset(sbcls.__param_spec__) not in cls.__param_registry, ( + "Duplicate __param_spec__" + ) cls.__param_registry[frozenset(sbcls.__param_spec__)] = sbcls assert sbcls.short_name not in cls.__name_registry, "Duplicate short_name" cls.__name_registry[sbcls.short_name] = sbcls @@ -1234,7 +1234,7 @@ def __init__(self, groups, model=None): for g in groups: if g.group is None: if rest is not None: - raise GroupError("More than one group is specified for " "the rest variables") + raise GroupError("More than one group is specified for the rest variables") else: rest = g else: diff --git a/pymc/variational/updates.py b/pymc/variational/updates.py index 234d307500..6818e12ae4 100644 --- a/pymc/variational/updates.py +++ b/pymc/variational/updates.py @@ -1006,7 +1006,7 @@ def norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-7): elif ndim in [3, 4, 5]: # Conv{1,2,3}DLayer sum_over = tuple(range(1, ndim)) else: - raise ValueError(f"Unsupported tensor dimensionality {ndim}." "Must specify `norm_axes`") + raise ValueError(f"Unsupported tensor dimensionality {ndim}.Must specify `norm_axes`") dtype = np.dtype(pytensor.config.floatX).type norms = pt.sqrt(pt.sum(pt.sqr(tensor_var), axis=sum_over, keepdims=True)) diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index cfd50fdd71..d988718fed 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1531,14 +1531,14 @@ class TestZeroSumNormal: def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True): if check_zerosum_axes: for ax in axes_to_check: - assert np.isclose( - random_samples.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + assert np.isclose(random_samples.mean(axis=ax), 0).all(), ( + f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + ) else: for ax in axes_to_check: - assert not np.isclose( - random_samples.mean(axis=ax), 0 - ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + assert not np.isclose(random_samples.mean(axis=ax), 0).all(), ( + f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + ) @pytest.mark.parametrize( "dims, n_zerosum_axes", diff --git a/tests/gp/test_hsgp_approx.py b/tests/gp/test_hsgp_approx.py index db03c8b8bc..b18577cde5 100644 --- a/tests/gp/test_hsgp_approx.py +++ b/tests/gp/test_hsgp_approx.py @@ -135,9 +135,9 @@ def test_mean_invariance(self): with model: pm.set_data({"X": x_new}) - assert np.allclose( - gp._X_center, original_center - ), "gp._X_center should not change after updating data for out-of-sample predictions." + assert np.allclose(gp._X_center, original_center), ( + "gp._X_center should not change after updating data for out-of-sample predictions." + ) def test_parametrization(self): err_msg = ( @@ -188,9 +188,9 @@ def test_parametrization_drop_first(self, model, cov_func, X1, drop_first): n_coeffs = model.f1_hsgp_coeffs.type.shape[0] if drop_first: - assert ( - n_coeffs == n_basis - 1 - ), f"one basis vector should have been dropped, {n_coeffs}" + assert n_coeffs == n_basis - 1, ( + f"one basis vector should have been dropped, {n_coeffs}" + ) else: assert n_coeffs == n_basis, "one was dropped when it shouldn't have been" diff --git a/tests/test_data.py b/tests/test_data.py index 2ba66dc744..695058c87e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -318,8 +318,8 @@ def test_explicit_coords(self, seeded_test): N_cols = 7 data = np.random.uniform(size=(N_rows, N_cols)) coords = { - "rows": [f"R{r+1}" for r in range(N_rows)], - "columns": [f"C{c+1}" for c in range(N_cols)], + "rows": [f"R{r + 1}" for r in range(N_rows)], + "columns": [f"C{c + 1}" for c in range(N_cols)], } # pass coordinates explicitly, use numpy array in Data container with pm.Model(coords=coords) as pmodel: @@ -391,7 +391,7 @@ def test_implicit_coords_dataframe(self, seeded_test): N_cols = 7 df_data = pd.DataFrame() for c in range(N_cols): - df_data[f"Column {c+1}"] = np.random.normal(size=(N_rows,)) + df_data[f"Column {c + 1}"] = np.random.normal(size=(N_rows,)) df_data.index.name = "rows" df_data.columns.name = "columns" From d308e0cbee76538350ffb8c4c772e10de0a017c8 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Thu, 16 Jan 2025 10:29:53 +0100 Subject: [PATCH 2/3] Update pymc/variational/updates.py Co-authored-by: Thomas Wiecki --- pymc/variational/updates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/updates.py b/pymc/variational/updates.py index 6818e12ae4..07c241beca 100644 --- a/pymc/variational/updates.py +++ b/pymc/variational/updates.py @@ -1006,7 +1006,7 @@ def norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-7): elif ndim in [3, 4, 5]: # Conv{1,2,3}DLayer sum_over = tuple(range(1, ndim)) else: - raise ValueError(f"Unsupported tensor dimensionality {ndim}.Must specify `norm_axes`") + raise ValueError(f"Unsupported tensor dimensionality {ndim}. Must specify `norm_axes`") dtype = np.dtype(pytensor.config.floatX).type norms = pt.sqrt(pt.sum(pt.sqr(tensor_var), axis=sum_over, keepdims=True)) From 5b4645313e5e30a1256cd4e000d60dbe95279761 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Thu, 16 Jan 2025 10:30:31 +0100 Subject: [PATCH 3/3] Update pymc/sampling/jax.py Co-authored-by: Thomas Wiecki --- pymc/sampling/jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 2823d0cfff..4f8ae2a5af 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -359,7 +359,7 @@ def _sample_blackjax_nuts( map_fn = jax.vmap else: raise ValueError( - 'Only supporting the following methods to draw chains: "parallel" or "vectorized"' + "Only supporting the following methods to draw chains: 'parallel' or 'vectorized'" ) if chains == 1: