
Commit 71d3ac0

Add ruff format for docstrings (#591)
* add ruff docstring format to pyproject
* run pre-commit
1 parent 504ef1c commit 71d3ac0
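Note: the pyproject.toml hunk itself is not expanded on this page, but enabling ruff's docstring code formatting is normally a small addition under the formatter table. A minimal sketch of what the added configuration likely looks like (the docstring-code-line-length value below is an assumption, not taken from this commit):

[tool.ruff.format]
# Reformat code examples embedded in docstrings; this is what produced the diffs below.
docstring-code-format = true
# Optional: let docstring code width adapt to the surrounding indentation.
docstring-code-line-length = "dynamic"

With this in place, running pre-commit (which invokes ruff format) rewrites the docstring examples in the files listed below.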

12 files changed: +75 −54 lines


pymc_extras/deserialize.py

Lines changed: 10 additions & 4 deletions
@@ -13,10 +13,7 @@

from pymc_extras.deserialize import deserialize

- prior_class_data = {
-     "dist": "Normal",
-     "kwargs": {"mu": 0, "sigma": 1}
- }
+ prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
prior = deserialize(prior_class_data)
# Prior("Normal", mu=0, sigma=1)

@@ -26,6 +23,7 @@

from pymc_extras.deserialize import register_deserialization

+
class MyClass:
    def __init__(self, value: int):
        self.value = value
@@ -34,6 +32,7 @@ def to_dict(self) -> dict:
        # Example of what the to_dict method might look like.
        return {"value": self.value}

+
register_deserialization(
    is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
    deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:

from typing import Any

+
class MyClass:
    def __init__(self, value: int):
        self.value = value

+
from pymc_extras.deserialize import Deserializer

+
def is_type(data: Any) -> bool:
    return data.keys() == {"value"} and isinstance(data["value"], int)

+
def deserialize(data: dict) -> MyClass:
    return MyClass(value=data["value"])

+
deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)

"""
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:

from pymc_extras.deserialize import register_deserialization

+
class MyClass:
    def __init__(self, value: int):
        self.value = value
@@ -204,6 +209,7 @@ def to_dict(self) -> dict:
        # Example of what the to_dict method might look like.
        return {"value": self.value}

+
register_deserialization(
    is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
    deserialize=lambda data: MyClass(value=data["value"]),

pymc_extras/distributions/continuous.py

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ class Chi:
from pymc_extras.distributions import Chi

with pm.Model():
-     x = Chi('x', nu=1)
+     x = Chi("x", nu=1)
"""

@staticmethod

pymc_extras/distributions/histogram_utils.py

Lines changed: 6 additions & 4 deletions
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
...     m = pm.Normal("m", dims="tests")
...     s = pm.LogNormal("s", dims="tests")
...     pot = pmx.distributions.histogram_approximation(
- ...         "pot", pm.Normal.dist(m, s),
- ...         observed=measurements, n_quantiles=50
+ ...         "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
...     )

For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
...     m = pm.Normal("m", dims="tests")
...     s = pm.LogNormal("s", dims="tests")
...     pot = pmx.distributions.histogram_approximation(
- ...         "pot", pm.Normal.dist(m, s),
- ...         observed=measurements, n_quantiles=50, zero_inflation=True
+ ...         "pot",
+ ...         pm.Normal.dist(m, s),
+ ...         observed=measurements,
+ ...         n_quantiles=50,
+ ...         zero_inflation=True,
...     )
"""
try:

pymc_extras/distributions/multivariate/r2d2m2cp.py

Lines changed: 4 additions & 3 deletions
@@ -305,6 +305,7 @@ def R2D2M2CP(
import pymc_extras as pmx
import pymc as pm
import numpy as np
+
X = np.random.randn(10, 3)
b = np.random.randn(3)
y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
    # "c" - a must have in the relation
    variables_importance=[10, 1, 34],
    # NOTE: try both
-     centered=True
+     centered=True,
)
# intercept prior centering should be around prior predictive mean
intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
    r2_std=0.2,
    # NOTE: if you know where a variable should go
    # if you do not know, leave as 0.5
-     centered=False
+     centered=False,
)
# intercept prior centering should be around prior predictive mean
intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
    # if you do not know, leave as 0.5
    positive_probs=[0.8, 0.5, 0.1],
    # NOTE: try both
-     centered=True
+     centered=True,
)
intercept = y.mean()
obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)

pymc_extras/distributions/timeseries.py

Lines changed: 4 additions & 2 deletions
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):

with pm.Model() as markov_chain:
    P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-     init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-     markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+     init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+     markov_chain = pmx.DiscreteMarkovChain(
+         "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+     )

"""

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 1 addition & 1 deletion
@@ -354,7 +354,7 @@ def fit_laplace(
>>> import numpy as np
>>> import pymc as pm
>>> import arviz as az
- >>> y = np.array([2642, 3503, 4358]*10)
+ >>> y = np.array([2642, 3503, 4358] * 10)
>>> with pm.Model() as m:
>>>     logsigma = pm.Uniform("logsigma", 1, 100)
>>>     mu = pm.Uniform("mu", -10000, 10000)

pymc_extras/inference/pathfinder/idata.py

Lines changed: 17 additions & 8 deletions
@@ -116,11 +116,10 @@ def pathfinder_result_to_xarray(
>>> with pm.Model() as model:
...     x = pm.Normal("x", 0, 1)
...     y = pm.Normal("y", x, 1, observed=2.0)
- ...
>>> # Assuming we have a PathfinderResult from a pathfinder run
>>> ds = pathfinder_result_to_xarray(result, model=model)
>>> print(ds.data_vars) # Shows lbfgs_niter, elbo_argmax, status info, etc.
- >>> print(ds.attrs) # Shows metadata like lbfgs_status, path_status
+ >>> print(ds.attrs)  # Shows metadata like lbfgs_status, path_status
"""
data_vars = {}
coords = {}
@@ -214,9 +213,16 @@ def multipathfinder_result_to_xarray(
>>> # Assuming we have a MultiPathfinderResult from multiple pathfinder runs
>>> ds = multipathfinder_result_to_xarray(result, model=model)
>>> print("All data:", ds.data_vars)
- >>> print("Summary:", [k for k in ds.data_vars.keys() if not k.startswith(('paths/', 'config/', 'diagnostics/'))])
- >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith('paths/')])
- >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith('config/')])
+ >>> print(
+ ...     "Summary:",
+ ...     [
+ ...         k
+ ...         for k in ds.data_vars.keys()
+ ...         if not k.startswith(("paths/", "config/", "diagnostics/"))
+ ...     ],
+ ... )
+ >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith("paths/")])
+ >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith("config/")])
"""
n_params = result.samples.shape[-1] if result.samples is not None else None
param_coords = get_param_coords(model, n_params) if n_params is not None else None
@@ -477,13 +483,16 @@ def add_pathfinder_to_inference_data(
>>> with pm.Model() as model:
...     x = pm.Normal("x", 0, 1)
...     idata = pmx.fit(method="pathfinder", model=model, add_pathfinder_groups=False)
- ...
>>> # Assuming we have pathfinder results
>>> idata = add_pathfinder_to_inference_data(idata, results, model=model)
>>> print(list(idata.groups())) # Will show ['posterior', 'pathfinder']
>>> # Access nested data:
- >>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('paths/')]) # Per-path data
- >>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('config/')]) # Config data
+ >>> print(
+ ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("paths/")]
+ ... )  # Per-path data
+ >>> print(
+ ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("config/")]
+ ... )  # Config data
"""
# Detect if this is a multi-path result
# Use isinstance() as primary check, but fall back to duck typing for compatibility

pymc_extras/model_builder.py

Lines changed: 6 additions & 4 deletions
@@ -334,7 +334,9 @@ def set_idata_attrs(self, idata=None):
>>> model = MyModel(ModelBuilder)
>>> idata = az.InferenceData(your_dataset)
>>> model.set_idata_attrs(idata=idata)
- >>> assert "id" in idata.attrs #this and the following lines are part of doctest, not user manual
+ >>> assert (
+ ...     "id" in idata.attrs
+ ... )  # this and the following lines are part of doctest, not user manual
>>> assert "model_type" in idata.attrs
>>> assert "version" in idata.attrs
>>> assert "sampler_config" in idata.attrs
@@ -382,7 +384,7 @@ def save(self, fname: str) -> None:
>>> super().__init__()
>>> model = MyModel()
>>> model.fit(data)
- >>> model.save('model_results.nc') # This will call the overridden method in MyModel
+ >>> model.save("model_results.nc")  # This will call the overridden method in MyModel
"""
if self.idata is not None and "posterior" in self.idata:
    file = Path(str(fname))
@@ -432,7 +434,7 @@ def load(cls, fname: str):
--------
>>> class MyModel(ModelBuilder):
>>>     ...
- >>> name = './mymodel.nc'
+ >>> name = "./mymodel.nc"
>>> imported_model = MyModel.load(name)
"""
filepath = Path(str(fname))
@@ -554,7 +556,7 @@ def predict(
>>> model = MyModel()
>>> idata = model.fit(data)
>>> x_pred = []
- >>> prediction_data = pd.DataFrame({'input':x_pred})
+ >>> prediction_data = pd.DataFrame({"input": x_pred})
>>> pred_mean = model.predict(prediction_data)
"""

pymc_extras/prior.py

Lines changed: 9 additions & 3 deletions
@@ -70,8 +70,10 @@

from pymc_extras.prior import register_tensor_transform

+
def custom_transform(x):
-     return x ** 2
+     return x**2
+

register_tensor_transform("square", custom_transform)

@@ -228,8 +230,10 @@ def register_tensor_transform(name: str, transform: Transform) -> None:
    register_tensor_transform,
)

+
def custom_transform(x):
-     return x ** 2
+     return x**2
+

register_tensor_transform("square", custom_transform)

@@ -316,14 +320,16 @@ def sample_prior(

from pymc_extras.prior import sample_prior

+
class CustomVariableDefinition:
    def __init__(self, dims, n: int):
        self.dims = dims
        self.n = n

    def create_variable(self, name: str) -> "TensorVariable":
        x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
-         return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)
+         return pt.sum([x**n for n in range(1, self.n + 1)], axis=0)
+

cubic = CustomVariableDefinition(dims=("channel",), n=3)
coords = {"channel": ["C1", "C2", "C3"]}

pymc_extras/utils/prior.py

Lines changed: 10 additions & 14 deletions
@@ -176,20 +176,16 @@ def prior_from_idata(

>>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2:
...     priors = prior_from_idata(
- ...         trace, # the old trace (posterior)
- ...         var_names=["a", "d"], # take variables as is
- ...
- ...         e="new_e", # assign new name "new_e" for a variable
- ...         # similar to dict(name="new_e")
- ...
- ...         b=("test", ), # set a dim to "test"
- ...         # similar to dict(dims=("test", ))
- ...
- ...         c=transforms.log, # apply log transform to a positive variable
- ...         # similar to dict(transform=transforms.log)
- ...
- ...         # set a name, assign a dim and apply simplex transform
- ...         f=dict(name="new_f", dims="options", transform=transforms.simplex)
+ ...         trace,  # the old trace (posterior)
+ ...         var_names=["a", "d"],  # take variables as is
+ ...         e="new_e",  # assign new name "new_e" for a variable
+ ...         # similar to dict(name="new_e")
+ ...         b=("test",),  # set a dim to "test"
+ ...         # similar to dict(dims=("test", ))
+ ...         c=transforms.log,  # apply log transform to a positive variable
+ ...         # similar to dict(transform=transforms.log)
+ ...         # set a name, assign a dim and apply simplex transform
+ ...         f=dict(name="new_f", dims="options", transform=transforms.simplex),
...     )
...     trace1 = pm.sample_prior_predictive(100)
"""
