Skip to content

Commit 68d2ff2

Browse files
Rename transformer to CholeskyCorrTransformer
1 parent cf8d9a8 commit 68d2ff2

File tree

3 files changed

+17
-18
lines changed

3 files changed

+17
-18
lines changed

pymc/distributions/multivariate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
to_tuple,
7575
)
7676
from pymc.distributions.transforms import (
77-
CholeskyCorr,
77+
CholeskyCorrTransform,
7878
Interval,
7979
ZeroSumTransform,
8080
_default_transform,
@@ -1652,9 +1652,9 @@ def logp(value, n, eta):
16521652

16531653
@_default_transform.register(_LKJCorr)
16541654
def lkjcorr_default_transform(op, rv):
1655-
_, _, _, n, *_ = rv.owner.inputs
1655+
rng, shape, n, eta, *_ = rv.owner.inputs = rv.owner.inputs
16561656
n = pt.get_scalar_constant_value(n) # Safely extract scalar value without eval
1657-
return CholeskyCorr(n)
1657+
return CholeskyCorrTransform(n=n, upper=False)
16581658

16591659

16601660
class LKJCorr:

pymc/distributions/transforms.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
__all__ = [
3535
"Chain",
3636
"Chain",
37-
"CholeskyCorr",
37+
"CholeskyCorrTransform",
3838
"CholeskyCovPacked",
3939
"CholeskyCovPacked",
4040
"Interval",
@@ -140,7 +140,7 @@ def log_jac_det(self, value, *inputs):
140140
return pt.sum(y, axis=-1)
141141

142142

143-
class CholeskyCorr(Transform):
143+
class CholeskyCorrTransform(Transform):
144144
"""
145145
Map an unconstrained real vector the Cholesky factor of a correlation matrix.
146146
@@ -181,7 +181,7 @@ class CholeskyCorr(Transform):
181181
https://github.com/tensorflow/probability/
182182
"""
183183

184-
name = "cholesky-corr"
184+
name = "cholesky_corr"
185185

186186
def __init__(self, n, upper: bool = False):
187187
"""
@@ -267,15 +267,14 @@ def _fill_triangular_spiral(
267267
upper = self.upper
268268

269269
if unit_diag:
270-
m -= n
271-
n -= 1
270+
n = n - 1
272271

273272
tail = x_raveled[..., n:]
274273

275274
if upper:
276-
xc = pt.concatenate([x_raveled, pt.flip(tail, -1)])
275+
xc = pt.concatenate([x_raveled, pt.flip(tail, -1)], axis=-1)
277276
else:
278-
xc = pt.concatenate([tail, pt.flip(x_raveled, -1)])
277+
xc = pt.concatenate([tail, pt.flip(x_raveled, -1)], axis=-1)
279278

280279
y = pt.reshape(xc, (*batch_shape, n, n))
281280
return pt.triu(y) if upper else pt.tril(y)
@@ -306,8 +305,8 @@ def _inverse_fill_triangular_spiral(
306305
n, m = self.n, self.m
307306

308307
if unit_diag:
309-
m -= n
310-
n -= 1
308+
m = m - n
309+
n = n - 1
311310

312311
upper = self.upper
313312

tests/distributions/test_transform.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ def log_jac_det(self, value, *inputs):
667667
m.logp(jacobian=jacobian_val)
668668

669669

670-
class TestLJKCholeskyCorr:
670+
class TestLJKCholeskyCorrTransform:
671671
def _get_test_values(self):
672672
x_unconstrained = np.array([2.0, 2.0, 1.0])
673673
x_constrained = np.array(
@@ -696,7 +696,7 @@ def test_fill_triangular_spiral(self, upper):
696696
]
697697
)
698698

699-
transform = tr.CholeskyCorr(n=3, upper=upper)
699+
transform = tr.CholeskyCorrTransform(n=3, upper=upper)
700700

701701
np.testing.assert_allclose(
702702
transform._fill_triangular_spiral(x_unconstrained, unit_diag=False).eval(),
@@ -709,7 +709,7 @@ def test_fill_triangular_spiral(self, upper):
709709
)
710710

711711
def test_forward(self):
712-
transform = tr.CholeskyCorr(n=3, upper=False)
712+
transform = tr.CholeskyCorrTransform(n=3, upper=False)
713713
x_unconstrained, x_constrained = self._get_test_values()
714714

715715
np.testing.assert_allclose(
@@ -719,7 +719,7 @@ def test_forward(self):
719719
)
720720

721721
def test_backward(self):
722-
transform = tr.CholeskyCorr(n=3, upper=False)
722+
transform = tr.CholeskyCorrTransform(n=3, upper=False)
723723
x_unconstrained, x_constrained = self._get_test_values()
724724

725725
np.testing.assert_allclose(
@@ -729,7 +729,7 @@ def test_backward(self):
729729
)
730730

731731
def test_transform_round_trip(self):
732-
transform = tr.CholeskyCorr(n=3, upper=False)
732+
transform = tr.CholeskyCorrTransform(n=3, upper=False)
733733
x_unconstrained, x_constrained = self._get_test_values()
734734

735735
constrained_reconstructed = transform.backward(transform.forward(x_constrained)).eval()
@@ -739,7 +739,7 @@ def test_transform_round_trip(self):
739739
np.testing.assert_allclose(x_constrained, constrained_reconstructed, atol=1e-6)
740740

741741
def test_log_jac_det(self):
742-
transform = tr.CholeskyCorr(n=3, upper=False)
742+
transform = tr.CholeskyCorrTransform(n=3, upper=False)
743743
x_unconstrained, x_constrained = self._get_test_values()
744744

745745
computed_log_jac_det = transform.log_jac_det(x_unconstrained).eval()

0 commit comments

Comments
 (0)