14 changes: 10 additions & 4 deletions pymc_extras/deserialize.py
@@ -13,10 +13,7 @@

from pymc_extras.deserialize import deserialize

-prior_class_data = {
-    "dist": "Normal",
-    "kwargs": {"mu": 0, "sigma": 1}
-}
+prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
prior = deserialize(prior_class_data)
# Prior("Normal", mu=0, sigma=1)

@@ -26,6 +23,7 @@

from pymc_extras.deserialize import register_deserialization

+
class MyClass:
    def __init__(self, value: int):
        self.value = value
@@ -34,6 +32,7 @@ def to_dict(self) -> dict:
        # Example of what the to_dict method might look like.
        return {"value": self.value}

+
register_deserialization(
    is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
    deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:

from typing import Any

+
class MyClass:
    def __init__(self, value: int):
        self.value = value

+
from pymc_extras.deserialize import Deserializer

+
def is_type(data: Any) -> bool:
    return data.keys() == {"value"} and isinstance(data["value"], int)

+
def deserialize(data: dict) -> MyClass:
    return MyClass(value=data["value"])

+
deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)

"""
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:

from pymc_extras.deserialize import register_deserialization

+
class MyClass:
    def __init__(self, value: int):
        self.value = value
@@ -204,6 +209,7 @@ def to_dict(self) -> dict:
        # Example of what the to_dict method might look like.
        return {"value": self.value}

+
register_deserialization(
    is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
    deserialize=lambda data: MyClass(value=data["value"]),
2 changes: 1 addition & 1 deletion pymc_extras/distributions/continuous.py
@@ -265,7 +265,7 @@ class Chi:
from pymc_extras.distributions import Chi

with pm.Model():
-    x = Chi('x', nu=1)
+    x = Chi("x", nu=1)
"""

@staticmethod
10 changes: 6 additions & 4 deletions pymc_extras/distributions/histogram_utils.py
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
...     m = pm.Normal("m", dims="tests")
...     s = pm.LogNormal("s", dims="tests")
...     pot = pmx.distributions.histogram_approximation(
-...         "pot", pm.Normal.dist(m, s),
-...         observed=measurements, n_quantiles=50
+...         "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
...     )
For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
...     m = pm.Normal("m", dims="tests")
...     s = pm.LogNormal("s", dims="tests")
...     pot = pmx.distributions.histogram_approximation(
-...         "pot", pm.Normal.dist(m, s),
-...         observed=measurements, n_quantiles=50, zero_inflation=True
+...         "pot",
+...         pm.Normal.dist(m, s),
+...         observed=measurements,
+...         n_quantiles=50,
+...         zero_inflation=True,
...     )
"""
try:
7 changes: 4 additions & 3 deletions pymc_extras/distributions/multivariate/r2d2m2cp.py
@@ -305,6 +305,7 @@ def R2D2M2CP(
import pymc_extras as pmx
import pymc as pm
import numpy as np
+
X = np.random.randn(10, 3)
b = np.random.randn(3)
y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
# "c" - a must have in the relation
variables_importance=[10, 1, 34],
# NOTE: try both
-centered=True
+centered=True,
)
# intercept prior centering should be around prior predictive mean
intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
r2_std=0.2,
# NOTE: if you know where a variable should go
# if you do not know, leave as 0.5
-centered=False
+centered=False,
)
# intercept prior centering should be around prior predictive mean
intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
# if you do not know, leave as 0.5
positive_probs=[0.8, 0.5, 0.1],
# NOTE: try both
-centered=True
+centered=True,
)
intercept = y.mean()
obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
6 changes: 4 additions & 2 deletions pymc_extras/distributions/timeseries.py
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):

with pm.Model() as markov_chain:
    P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-    init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-    markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+    init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+    markov_chain = pmx.DiscreteMarkovChain(
+        "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+    )

"""

2 changes: 1 addition & 1 deletion pymc_extras/inference/laplace_approx/laplace.py
@@ -354,7 +354,7 @@ def fit_laplace(
>>> import numpy as np
>>> import pymc as pm
>>> import arviz as az
->>> y = np.array([2642, 3503, 4358]*10)
+>>> y = np.array([2642, 3503, 4358] * 10)
>>> with pm.Model() as m:
>>> logsigma = pm.Uniform("logsigma", 1, 100)
>>> mu = pm.Uniform("mu", -10000, 10000)
25 changes: 17 additions & 8 deletions pymc_extras/inference/pathfinder/idata.py
@@ -116,11 +116,10 @@ def pathfinder_result_to_xarray(
>>> with pm.Model() as model:
... x = pm.Normal("x", 0, 1)
... y = pm.Normal("y", x, 1, observed=2.0)
-...
>>> # Assuming we have a PathfinderResult from a pathfinder run
>>> ds = pathfinder_result_to_xarray(result, model=model)
>>> print(ds.data_vars) # Shows lbfgs_niter, elbo_argmax, status info, etc.
->>> print(ds.attrs) # Shows metadata like lbfgs_status, path_status
+>>> print(ds.attrs)  # Shows metadata like lbfgs_status, path_status
"""
data_vars = {}
coords = {}
@@ -214,9 +213,16 @@ def multipathfinder_result_to_xarray(
>>> # Assuming we have a MultiPathfinderResult from multiple pathfinder runs
>>> ds = multipathfinder_result_to_xarray(result, model=model)
>>> print("All data:", ds.data_vars)
>>> print("Summary:", [k for k in ds.data_vars.keys() if not k.startswith(('paths/', 'config/', 'diagnostics/'))])
>>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith('paths/')])
>>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith('config/')])
>>> print(
... "Summary:",
... [
... k
... for k in ds.data_vars.keys()
... if not k.startswith(("paths/", "config/", "diagnostics/"))
... ],
... )
>>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith("paths/")])
>>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith("config/")])
"""
n_params = result.samples.shape[-1] if result.samples is not None else None
param_coords = get_param_coords(model, n_params) if n_params is not None else None
@@ -477,13 +483,16 @@ def add_pathfinder_to_inference_data(
>>> with pm.Model() as model:
... x = pm.Normal("x", 0, 1)
... idata = pmx.fit(method="pathfinder", model=model, add_pathfinder_groups=False)
-...
>>> # Assuming we have pathfinder results
>>> idata = add_pathfinder_to_inference_data(idata, results, model=model)
>>> print(list(idata.groups())) # Will show ['posterior', 'pathfinder']
>>> # Access nested data:
->>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('paths/')]) # Per-path data
->>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('config/')]) # Config data
+>>> print(
+...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("paths/")]
+... )  # Per-path data
+>>> print(
+...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("config/")]
+... )  # Config data
"""
# Detect if this is a multi-path result
# Use isinstance() as primary check, but fall back to duck typing for compatibility
10 changes: 6 additions & 4 deletions pymc_extras/model_builder.py
@@ -334,7 +334,9 @@ def set_idata_attrs(self, idata=None):
>>> model = MyModel(ModelBuilder)
>>> idata = az.InferenceData(your_dataset)
>>> model.set_idata_attrs(idata=idata)
>>> assert "id" in idata.attrs #this and the following lines are part of doctest, not user manual
>>> assert (
... "id" in idata.attrs
... ) # this and the following lines are part of doctest, not user manual
>>> assert "model_type" in idata.attrs
>>> assert "version" in idata.attrs
>>> assert "sampler_config" in idata.attrs
@@ -382,7 +384,7 @@ def save(self, fname: str) -> None:
>>> super().__init__()
>>> model = MyModel()
>>> model.fit(data)
->>> model.save('model_results.nc') # This will call the overridden method in MyModel
+>>> model.save("model_results.nc")  # This will call the overridden method in MyModel
"""
if self.idata is not None and "posterior" in self.idata:
file = Path(str(fname))
@@ -432,7 +434,7 @@ def load(cls, fname: str):
--------
>>> class MyModel(ModelBuilder):
>>> ...
->>> name = './mymodel.nc'
+>>> name = "./mymodel.nc"
>>> imported_model = MyModel.load(name)
"""
filepath = Path(str(fname))
@@ -554,7 +556,7 @@ def predict(
>>> model = MyModel()
>>> idata = model.fit(data)
>>> x_pred = []
->>> prediction_data = pd.DataFrame({'input':x_pred})
+>>> prediction_data = pd.DataFrame({"input": x_pred})
>>> pred_mean = model.predict(prediction_data)
"""

12 changes: 9 additions & 3 deletions pymc_extras/prior.py
@@ -70,8 +70,10 @@
from pymc_extras.prior import register_tensor_transform

+
def custom_transform(x):
-    return x ** 2
+    return x**2

+
register_tensor_transform("square", custom_transform)
@@ -228,8 +230,10 @@ def register_tensor_transform(name: str, transform: Transform) -> None:
    register_tensor_transform,
)

+
def custom_transform(x):
-    return x ** 2
+    return x**2

+
register_tensor_transform("square", custom_transform)
@@ -316,14 +320,16 @@ def sample_prior(
from pymc_extras.prior import sample_prior

+
class CustomVariableDefinition:
    def __init__(self, dims, n: int):
        self.dims = dims
        self.n = n

    def create_variable(self, name: str) -> "TensorVariable":
        x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
-        return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)
+        return pt.sum([x**n for n in range(1, self.n + 1)], axis=0)

+
cubic = CustomVariableDefinition(dims=("channel",), n=3)
coords = {"channel": ["C1", "C2", "C3"]}
24 changes: 10 additions & 14 deletions pymc_extras/utils/prior.py
@@ -176,20 +176,16 @@ def prior_from_idata(

>>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2:
...     priors = prior_from_idata(
-...         trace, # the old trace (posterior)
-...         var_names=["a", "d"], # take variables as is
-...
-...         e="new_e", # assign new name "new_e" for a variable
-...         # similar to dict(name="new_e")
-...
-...         b=("test", ), # set a dim to "test"
-...         # similar to dict(dims=("test", ))
-...
-...         c=transforms.log, # apply log transform to a positive variable
-...         # similar to dict(transform=transforms.log)
-...
-...         # set a name, assign a dim and apply simplex transform
-...         f=dict(name="new_f", dims="options", transform=transforms.simplex)
+...         trace,  # the old trace (posterior)
+...         var_names=["a", "d"],  # take variables as is
+...         e="new_e",  # assign new name "new_e" for a variable
+...         # similar to dict(name="new_e")
+...         b=("test",),  # set a dim to "test"
+...         # similar to dict(dims=("test", ))
+...         c=transforms.log,  # apply log transform to a positive variable
+...         # similar to dict(transform=transforms.log)
+...         # set a name, assign a dim and apply simplex transform
+...         f=dict(name="new_f", dims="options", transform=transforms.simplex),
... )
... trace1 = pm.sample_prior_predictive(100)
"""
14 changes: 4 additions & 10 deletions pymc_extras/utils/spline.py
@@ -97,19 +97,13 @@ def bspline_interpolation(x, *, n=None, eval_points=None, degree=3, sparse=True)
--------
>>> import pymc as pm
>>> import numpy as np
->>> half_months = np.linspace(0, 365, 12*2)
+>>> half_months = np.linspace(0, 365, 12 * 2)
>>> with pm.Model(coords=dict(knots_time=half_months, time=np.arange(365))) as model:
-...     kernel = pm.gp.cov.ExpQuad(1, ls=365/12)
+...     kernel = pm.gp.cov.ExpQuad(1, ls=365 / 12)
...     # ready to define gp (a latent process over parameters)
-...     gp = pm.gp.gp.Latent(
-...         cov_func=kernel
-...     )
+...     gp = pm.gp.gp.Latent(cov_func=kernel)
...     y_knots = gp.prior("y_knots", half_months[:, None], dims="knots_time")
-...     y = pm.Deterministic(
-...         "y",
-...         bspline_interpolation(y_knots, n=365, degree=3),
-...         dims="time"
-...     )
+...     y = pm.Deterministic("y", bspline_interpolation(y_knots, n=365, degree=3), dims="time")
...     trace = pm.sample_prior_predictive(1)

Notes
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -120,6 +120,9 @@ exclude_lines = [
line-length = 100
target-version = "py310"

+[tool.ruff.format]
+docstring-code-format = true
+
[tool.ruff.lint]
select = ["D", "E", "F", "I", "UP", "W", "RUF"]
ignore = [
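With docstring-code-format enabled, ruff format also rewrites the code examples embedded in docstrings (doctests and code blocks), which is what produced the docstring changes in the files above. A minimal sketch of applying and verifying the formatting locally, assuming a ruff release recent enough to support this setting and a checkout with the pyproject.toml change applied:

    # Rewrite files in place, including code inside docstrings
    ruff format pymc_extras/

    # Check-only mode (non-zero exit if anything would be reformatted), e.g. for CI
    ruff format --check pymc_extras/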