Skip to content

Commit 408adba

Browse files
committed
Linter fixes.
1 parent dcd3a8d commit 408adba

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

pymc/distributions/multivariate.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,12 @@
6969
rv_size_is_none,
7070
to_tuple,
7171
)
72-
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
72+
from pymc.distributions.transforms import (
73+
CholeskyCorr,
74+
Interval,
75+
ZeroSumTransform,
76+
_default_transform,
77+
)
7378
from pymc.logprob.abstract import _logprob
7479
from pymc.math import kron_diag, kron_dot
7580
from pymc.pytensorf import normalize_rng_param

pymc/distributions/transforms.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import numpy as np
1919
import pytensor.tensor as pt
20-
import pytensor
2120

2221

2322
# ignore mypy error because it somehow considers that
@@ -213,13 +212,10 @@ def forward(self, x, *inputs):
213212
chol = pt.zeros((self.n, self.n), dtype=x.dtype)
214213

215214
# Assign the unconstrained values to the lower triangular part
216-
chol = pt.set_subtensor(
217-
chol[self.tril_r_idxs, self.tril_c_idxs],
218-
x
219-
)
215+
chol = pt.set_subtensor(chol[self.tril_r_idxs, self.tril_c_idxs], x)
220216

221217
# Normalize each row to have unit L2 norm
222-
row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1, keepdims=True))
218+
row_norms = pt.sqrt(pt.sum(chol**2, axis=1, keepdims=True))
223219
chol = chol / row_norms
224220

225221
return chol[self.tril_r_idxs, self.tril_c_idxs]
@@ -240,17 +236,14 @@ def backward(self, y, *inputs):
240236
"""
241237
# Reconstruct the full Cholesky matrix
242238
chol = pt.zeros((self.n, self.n), dtype=y.dtype)
243-
chol = pt.set_subtensor(
244-
chol[self.triu_r_idxs, self.triu_c_idxs],
245-
y
246-
)
239+
chol = pt.set_subtensor(chol[self.triu_r_idxs, self.triu_c_idxs], y)
247240
chol = chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype)
248241

249242
# Perform Cholesky decomposition
250243
chol = pt.linalg.cholesky(chol)
251244

252245
# Extract the unconstrained parameters by normalizing
253-
row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1))
246+
row_norms = pt.sqrt(pt.sum(chol**2, axis=1))
254247
unconstrained = chol / row_norms[:, None]
255248

256249
return unconstrained[self.tril_r_idxs, self.tril_c_idxs]
@@ -261,7 +254,7 @@ def log_jac_det(self, y, *inputs):
261254
262255
The Jacobian determinant for normalization is the product of row norms.
263256
"""
264-
row_norms = pt.sqrt(pt.sum(y ** 2, axis=1))
257+
row_norms = pt.sqrt(pt.sum(y**2, axis=1))
265258
return -pt.sum(pt.log(row_norms), axis=-1)
266259

267260

tests/distributions/test_transform.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
import pymc as pm
2525
import pymc.distributions.transforms as tr
26-
from pymc.distributions.transforms import CholeskyCorr
2726

27+
from pymc.distributions.transforms import CholeskyCorr
2828
from pymc.logprob.basic import transformed_conditional_logp
2929
from pymc.logprob.transforms import Transform
3030
from pymc.pytensorf import floatX, jacobian
@@ -684,7 +684,9 @@ def test_lkjcorr_transform_round_trip():
684684
with pm.Model() as model:
685685
rho = pm.LKJCorr("rho", n=3, eta=2)
686686

687-
trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False)
687+
trace = pm.sample(
688+
100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False
689+
)
688690

689691
# Extract the sampled correlation matrices
690692
rho_samples = trace["rho"]
@@ -763,7 +765,7 @@ def test_lkjcorr_invalid_n():
763765

764766
with pytest.raises(TypeError):
765767
# 'n' must be an integer
766-
CholeskyCorr(n='three')
768+
CholeskyCorr(n="three")
767769

768770

769771
def test_lkjcorr_positive_definite():
@@ -773,7 +775,9 @@ def test_lkjcorr_positive_definite():
773775
with pm.Model() as model:
774776
rho = pm.LKJCorr("rho", n=4, eta=2)
775777

776-
trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False)
778+
trace = pm.sample(
779+
100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False
780+
)
777781

778782
# Extract the sampled correlation matrices
779783
rho_samples = trace["rho"]
@@ -808,4 +812,4 @@ def test_lkjcorr_round_trip_various_sizes():
808812
reconstructed = transform.backward(y).eval()
809813

810814
# Assert that the original and reconstructed unconstrained parameters are close
811-
assert_allclose(x, reconstructed, atol=1e-6)
815+
assert_allclose(x, reconstructed, atol=1e-6)

0 commit comments

Comments
 (0)