Skip to content

Commit 55313ab

Browse files
twieckijessegrabowski
authored andcommitted
Linter fixes.
1 parent 1ade4a4 commit 55313ab

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

pymc/distributions/multivariate.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,12 @@
7373
rv_size_is_none,
7474
to_tuple,
7575
)
76-
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
76+
from pymc.distributions.transforms import (
77+
CholeskyCorr,
78+
Interval,
79+
ZeroSumTransform,
80+
_default_transform,
81+
)
7782
from pymc.logprob.abstract import _logprob
7883
from pymc.logprob.rewriting import (
7984
specialization_ir_rewrites_db,

pymc/distributions/transforms.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import numpy as np
1818
import pytensor.tensor as pt
19-
import pytensor
2019

2120
from pytensor.graph import Op
2221
from pytensor.npy_2_compat import normalize_axis_tuple
@@ -214,13 +213,10 @@ def forward(self, x, *inputs):
214213
chol = pt.zeros((self.n, self.n), dtype=x.dtype)
215214

216215
# Assign the unconstrained values to the lower triangular part
217-
chol = pt.set_subtensor(
218-
chol[self.tril_r_idxs, self.tril_c_idxs],
219-
x
220-
)
216+
chol = pt.set_subtensor(chol[self.tril_r_idxs, self.tril_c_idxs], x)
221217

222218
# Normalize each row to have unit L2 norm
223-
row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1, keepdims=True))
219+
row_norms = pt.sqrt(pt.sum(chol**2, axis=1, keepdims=True))
224220
chol = chol / row_norms
225221

226222
return chol[self.tril_r_idxs, self.tril_c_idxs]
@@ -241,17 +237,14 @@ def backward(self, y, *inputs):
241237
"""
242238
# Reconstruct the full Cholesky matrix
243239
chol = pt.zeros((self.n, self.n), dtype=y.dtype)
244-
chol = pt.set_subtensor(
245-
chol[self.triu_r_idxs, self.triu_c_idxs],
246-
y
247-
)
240+
chol = pt.set_subtensor(chol[self.triu_r_idxs, self.triu_c_idxs], y)
248241
chol = chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype)
249242

250243
# Perform Cholesky decomposition
251244
chol = pt.linalg.cholesky(chol)
252245

253246
# Extract the unconstrained parameters by normalizing
254-
row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1))
247+
row_norms = pt.sqrt(pt.sum(chol**2, axis=1))
255248
unconstrained = chol / row_norms[:, None]
256249

257250
return unconstrained[self.tril_r_idxs, self.tril_c_idxs]
@@ -262,7 +255,7 @@ def log_jac_det(self, y, *inputs):
262255
263256
The Jacobian determinant for normalization is the product of row norms.
264257
"""
265-
row_norms = pt.sqrt(pt.sum(y ** 2, axis=1))
258+
row_norms = pt.sqrt(pt.sum(y**2, axis=1))
266259
return -pt.sum(pt.log(row_norms), axis=-1)
267260

268261

tests/distributions/test_transform.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
import pymc as pm
2525
import pymc.distributions.transforms as tr
26-
from pymc.distributions.transforms import CholeskyCorr
2726

27+
from pymc.distributions.transforms import CholeskyCorr
2828
from pymc.logprob.basic import transformed_conditional_logp
2929
from pymc.logprob.transforms import Transform
3030
from pymc.pytensorf import floatX, jacobian
@@ -676,7 +676,9 @@ def test_lkjcorr_transform_round_trip():
676676
with pm.Model() as model:
677677
rho = pm.LKJCorr("rho", n=3, eta=2)
678678

679-
trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False)
679+
trace = pm.sample(
680+
100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False
681+
)
680682

681683
# Extract the sampled correlation matrices
682684
rho_samples = trace["rho"]
@@ -755,7 +757,7 @@ def test_lkjcorr_invalid_n():
755757

756758
with pytest.raises(TypeError):
757759
# 'n' must be an integer
758-
CholeskyCorr(n='three')
760+
CholeskyCorr(n="three")
759761

760762

761763
def test_lkjcorr_positive_definite():
@@ -765,7 +767,9 @@ def test_lkjcorr_positive_definite():
765767
with pm.Model() as model:
766768
rho = pm.LKJCorr("rho", n=4, eta=2)
767769

768-
trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False)
770+
trace = pm.sample(
771+
100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False
772+
)
769773

770774
# Extract the sampled correlation matrices
771775
rho_samples = trace["rho"]
@@ -800,4 +804,4 @@ def test_lkjcorr_round_trip_various_sizes():
800804
reconstructed = transform.backward(y).eval()
801805

802806
# Assert that the original and reconstructed unconstrained parameters are close
803-
assert_allclose(x, reconstructed, atol=1e-6)
807+
assert_allclose(x, reconstructed, atol=1e-6)

0 commit comments

Comments
 (0)