diff --git a/pymc_extras/deserialize.py b/pymc_extras/deserialize.py
index ac58848e5..244d80e44 100644
--- a/pymc_extras/deserialize.py
+++ b/pymc_extras/deserialize.py
@@ -13,10 +13,7 @@

     from pymc_extras.deserialize import deserialize

-    prior_class_data = {
-        "dist": "Normal",
-        "kwargs": {"mu": 0, "sigma": 1}
-    }
+    prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
     prior = deserialize(prior_class_data)
     # Prior("Normal", mu=0, sigma=1)

@@ -26,6 +23,7 @@

     from pymc_extras.deserialize import register_deserialization

+
     class MyClass:
         def __init__(self, value: int):
             self.value = value
@@ -34,6 +32,7 @@ def to_dict(self) -> dict:
             # Example of what the to_dict method might look like.
             return {"value": self.value}

+
     register_deserialization(
         is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
         deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:

         from typing import Any

+
         class MyClass:
             def __init__(self, value: int):
                 self.value = value

+
         from pymc_extras.deserialize import Deserializer

+
         def is_type(data: Any) -> bool:
             return data.keys() == {"value"} and isinstance(data["value"], int)

+
         def deserialize(data: dict) -> MyClass:
             return MyClass(value=data["value"])

+
         deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
     """
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:

         from pymc_extras.deserialize import register_deserialization

+
         class MyClass:
             def __init__(self, value: int):
                 self.value = value
@@ -204,6 +209,7 @@ def to_dict(self) -> dict:
             # Example of what the to_dict method might look like.
             return {"value": self.value}

+
         register_deserialization(
             is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
             deserialize=lambda data: MyClass(value=data["value"]),
diff --git a/pymc_extras/distributions/continuous.py b/pymc_extras/distributions/continuous.py
index 0264a9e62..0bcda193e 100644
--- a/pymc_extras/distributions/continuous.py
+++ b/pymc_extras/distributions/continuous.py
@@ -265,7 +265,7 @@ class Chi:
         from pymc_extras.distributions import Chi

         with pm.Model():
-            x = Chi('x', nu=1)
+            x = Chi("x", nu=1)
     """

     @staticmethod
diff --git a/pymc_extras/distributions/histogram_utils.py b/pymc_extras/distributions/histogram_utils.py
index 5cf899032..2a20d79a4 100644
--- a/pymc_extras/distributions/histogram_utils.py
+++ b/pymc_extras/distributions/histogram_utils.py
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
     ...     m = pm.Normal("m", dims="tests")
     ...     s = pm.LogNormal("s", dims="tests")
     ...     pot = pmx.distributions.histogram_approximation(
-    ...         "pot", pm.Normal.dist(m, s),
-    ...         observed=measurements, n_quantiles=50
+    ...         "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
     ...     )

     For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
     ...     m = pm.Normal("m", dims="tests")
     ...     s = pm.LogNormal("s", dims="tests")
     ...     pot = pmx.distributions.histogram_approximation(
-    ...         "pot", pm.Normal.dist(m, s),
-    ...         observed=measurements, n_quantiles=50, zero_inflation=True
+    ...         "pot",
+    ...         pm.Normal.dist(m, s),
+    ...         observed=measurements,
+    ...         n_quantiles=50,
+    ...         zero_inflation=True,
     ...     )
     """
     try:
diff --git a/pymc_extras/distributions/multivariate/r2d2m2cp.py b/pymc_extras/distributions/multivariate/r2d2m2cp.py
index 71bd5bbad..18a5c53f1 100644
--- a/pymc_extras/distributions/multivariate/r2d2m2cp.py
+++ b/pymc_extras/distributions/multivariate/r2d2m2cp.py
@@ -305,6 +305,7 @@ def R2D2M2CP(
         import pymc_extras as pmx
         import pymc as pm
         import numpy as np
+
         X = np.random.randn(10, 3)
         b = np.random.randn(3)
         y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
                 # "c" - a must have in the relation
                 variables_importance=[10, 1, 34],
                 # NOTE: try both
-                centered=True
+                centered=True,
             )
             # intercept prior centering should be around prior predictive mean
             intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
                 r2_std=0.2,
                 # NOTE: if you know where a variable should go
                 # if you do not know, leave as 0.5
-                centered=False
+                centered=False,
             )
             # intercept prior centering should be around prior predictive mean
             intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
                 # if you do not know, leave as 0.5
                 positive_probs=[0.8, 0.5, 0.1],
                 # NOTE: try both
-                centered=True
+                centered=True,
             )
             intercept = y.mean()
             obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
diff --git a/pymc_extras/distributions/timeseries.py b/pymc_extras/distributions/timeseries.py
index 19772cff2..6bf3acdb2 100644
--- a/pymc_extras/distributions/timeseries.py
+++ b/pymc_extras/distributions/timeseries.py
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):

         with pm.Model() as markov_chain:
             P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-            init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-            markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+            init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+            markov_chain = pmx.DiscreteMarkovChain(
+                "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+            )

     """
diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py
index ab5203587..ee6b7ef90 100644
--- a/pymc_extras/inference/laplace_approx/laplace.py
+++ b/pymc_extras/inference/laplace_approx/laplace.py
@@ -354,7 +354,7 @@ def fit_laplace(
     >>> import numpy as np
     >>> import pymc as pm
     >>> import arviz as az
-    >>> y = np.array([2642, 3503, 4358]*10)
+    >>> y = np.array([2642, 3503, 4358] * 10)
     >>> with pm.Model() as m:
     >>>     logsigma = pm.Uniform("logsigma", 1, 100)
     >>>     mu = pm.Uniform("mu", -10000, 10000)
diff --git a/pymc_extras/inference/pathfinder/idata.py b/pymc_extras/inference/pathfinder/idata.py
index 02e105f3a..e00620d85 100644
--- a/pymc_extras/inference/pathfinder/idata.py
+++ b/pymc_extras/inference/pathfinder/idata.py
@@ -116,11 +116,10 @@ def pathfinder_result_to_xarray(
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...     y = pm.Normal("y", x, 1, observed=2.0)
-    ...
    >>> # Assuming we have a PathfinderResult from a pathfinder run
    >>> ds = pathfinder_result_to_xarray(result, model=model)
    >>> print(ds.data_vars)  # Shows lbfgs_niter, elbo_argmax, status info, etc.
-    >>> print(ds.attrs) # Shows metadata like lbfgs_status, path_status
+    >>> print(ds.attrs)  # Shows metadata like lbfgs_status, path_status
    """
    data_vars = {}
    coords = {}
@@ -214,9 +213,16 @@ def multipathfinder_result_to_xarray(
    >>> # Assuming we have a MultiPathfinderResult from multiple pathfinder runs
    >>> ds = multipathfinder_result_to_xarray(result, model=model)
    >>> print("All data:", ds.data_vars)
-    >>> print("Summary:", [k for k in ds.data_vars.keys() if not k.startswith(('paths/', 'config/', 'diagnostics/'))])
-    >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith('paths/')])
-    >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith('config/')])
+    >>> print(
+    ...     "Summary:",
+    ...     [
+    ...         k
+    ...         for k in ds.data_vars.keys()
+    ...         if not k.startswith(("paths/", "config/", "diagnostics/"))
+    ...     ],
+    ... )
+    >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith("paths/")])
+    >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith("config/")])
    """
    n_params = result.samples.shape[-1] if result.samples is not None else None
    param_coords = get_param_coords(model, n_params) if n_params is not None else None
@@ -477,13 +483,16 @@ def add_pathfinder_to_inference_data(
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...     idata = pmx.fit(method="pathfinder", model=model, add_pathfinder_groups=False)
-    ...
    >>> # Assuming we have pathfinder results
    >>> idata = add_pathfinder_to_inference_data(idata, results, model=model)
    >>> print(list(idata.groups()))  # Will show ['posterior', 'pathfinder']
    >>> # Access nested data:
-    >>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('paths/')])  # Per-path data
-    >>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('config/')])  # Config data
+    >>> print(
+    ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("paths/")]
+    ... )  # Per-path data
+    >>> print(
+    ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("config/")]
+    ... )  # Config data
    """
    # Detect if this is a multi-path result
    # Use isinstance() as primary check, but fall back to duck typing for compatibility
diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py
index 6e712e5d7..99edbc1f8 100644
--- a/pymc_extras/model_builder.py
+++ b/pymc_extras/model_builder.py
@@ -334,7 +334,9 @@ def set_idata_attrs(self, idata=None):
        >>> model = MyModel(ModelBuilder)
        >>> idata = az.InferenceData(your_dataset)
        >>> model.set_idata_attrs(idata=idata)
-        >>> assert "id" in idata.attrs #this and the following lines are part of doctest, not user manual
+        >>> assert (
+        ...     "id" in idata.attrs
+        ... )  # this and the following lines are part of doctest, not user manual
        >>> assert "model_type" in idata.attrs
        >>> assert "version" in idata.attrs
        >>> assert "sampler_config" in idata.attrs
@@ -382,7 +384,7 @@ def save(self, fname: str) -> None:
        >>>     super().__init__()
        >>> model = MyModel()
        >>> model.fit(data)
-        >>> model.save('model_results.nc')  # This will call the overridden method in MyModel
+        >>> model.save("model_results.nc")  # This will call the overridden method in MyModel
        """
        if self.idata is not None and "posterior" in self.idata:
            file = Path(str(fname))
@@ -432,7 +434,7 @@ def load(cls, fname: str):
        --------
        >>> class MyModel(ModelBuilder):
        >>>     ...
-        >>> name = './mymodel.nc'
+        >>> name = "./mymodel.nc"
        >>> imported_model = MyModel.load(name)
        """
        filepath = Path(str(fname))
@@ -554,7 +556,7 @@ def predict(
        >>> model = MyModel()
        >>> idata = model.fit(data)
        >>> x_pred = []
-        >>> prediction_data = pd.DataFrame({'input':x_pred})
+        >>> prediction_data = pd.DataFrame({"input": x_pred})
        >>> pred_mean = model.predict(prediction_data)

        """
diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py
index 0f368f986..41494fc2b 100644
--- a/pymc_extras/prior.py
+++ b/pymc_extras/prior.py
@@ -70,8 +70,10 @@

     from pymc_extras.prior import register_tensor_transform

+
     def custom_transform(x):
-        return x ** 2
+        return x**2
+
     register_tensor_transform("square", custom_transform)

@@ -228,8 +230,10 @@ def register_tensor_transform(name: str, transform: Transform) -> None:
            register_tensor_transform,
        )

+
        def custom_transform(x):
-            return x ** 2
+            return x**2
+
        register_tensor_transform("square", custom_transform)

@@ -316,6 +320,7 @@ def sample_prior(

        from pymc_extras.prior import sample_prior

+
        class CustomVariableDefinition:
            def __init__(self, dims, n: int):
                self.dims = dims
@@ -323,7 +328,8 @@ def __init__(self, dims, n: int):

            def create_variable(self, name: str) -> "TensorVariable":
                x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
-                return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)
+                return pt.sum([x**n for n in range(1, self.n + 1)], axis=0)
+
        cubic = CustomVariableDefinition(dims=("channel",), n=3)
        coords = {"channel": ["C1", "C2", "C3"]}
diff --git a/pymc_extras/utils/prior.py b/pymc_extras/utils/prior.py
index 15833c9e7..6fee94500 100644
--- a/pymc_extras/utils/prior.py
+++ b/pymc_extras/utils/prior.py
@@ -176,20 +176,16 @@ def prior_from_idata(

    >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2:
    ...     priors = prior_from_idata(
-    ...         trace, # the old trace (posterior)
-    ...         var_names=["a", "d"], # take variables as is
-    ...
-    ...         e="new_e", # assign new name "new_e" for a variable
-    ...         # similar to dict(name="new_e")
-    ...
-    ...         b=("test", ), # set a dim to "test"
-    ...         # similar to dict(dims=("test", ))
-    ...
-    ...         c=transforms.log, # apply log transform to a positive variable
-    ...         # similar to dict(transform=transforms.log)
-    ...
-    ...         # set a name, assign a dim and apply simplex transform
-    ...         f=dict(name="new_f", dims="options", transform=transforms.simplex)
+    ...         trace,  # the old trace (posterior)
+    ...         var_names=["a", "d"],  # take variables as is
+    ...         e="new_e",  # assign new name "new_e" for a variable
+    ...         # similar to dict(name="new_e")
+    ...         b=("test",),  # set a dim to "test"
+    ...         # similar to dict(dims=("test", ))
+    ...         c=transforms.log,  # apply log transform to a positive variable
+    ...         # similar to dict(transform=transforms.log)
+    ...         # set a name, assign a dim and apply simplex transform
+    ...         f=dict(name="new_f", dims="options", transform=transforms.simplex),
    ...     )
    ...     trace1 = pm.sample_prior_predictive(100)
    """
diff --git a/pymc_extras/utils/spline.py b/pymc_extras/utils/spline.py
index 921c8ec6f..1a1798ec1 100644
--- a/pymc_extras/utils/spline.py
+++ b/pymc_extras/utils/spline.py
@@ -97,19 +97,13 @@ def bspline_interpolation(x, *, n=None, eval_points=None, degree=3, sparse=True)
    --------
    >>> import pymc as pm
    >>> import numpy as np
-    >>> half_months = np.linspace(0, 365, 12*2)
+    >>> half_months = np.linspace(0, 365, 12 * 2)
    >>> with pm.Model(coords=dict(knots_time=half_months, time=np.arange(365))) as model:
-    ...     kernel = pm.gp.cov.ExpQuad(1, ls=365/12)
+    ...     kernel = pm.gp.cov.ExpQuad(1, ls=365 / 12)
    ...     # ready to define gp (a latent process over parameters)
-    ...     gp = pm.gp.gp.Latent(
-    ...         cov_func=kernel
-    ...     )
+    ...     gp = pm.gp.gp.Latent(cov_func=kernel)
    ...     y_knots = gp.prior("y_knots", half_months[:, None], dims="knots_time")
-    ...     y = pm.Deterministic(
-    ...         "y",
-    ...         bspline_interpolation(y_knots, n=365, degree=3),
-    ...         dims="time"
-    ...     )
+    ...     y = pm.Deterministic("y", bspline_interpolation(y_knots, n=365, degree=3), dims="time")
    ...     trace = pm.sample_prior_predictive(1)

    Notes
diff --git a/pyproject.toml b/pyproject.toml
index 52393e06b..3b864d690 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -120,6 +120,9 @@ exclude_lines = [
 line-length = 100
 target-version = "py310"

+[tool.ruff.format]
+docstring-code-format = true
+
 [tool.ruff.lint]
 select = ["D", "E", "F", "I", "UP", "W", "RUF"]
 ignore = [