Skip to content

Commit d8cc131

Browse files
committed
Generalize ordered transform
1 parent d34ed95 commit d8cc131

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

pymc/distributions/transforms.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,26 +89,48 @@ def log_jac_det(self, value, *inputs):
8989

9090

9191
class Ordered(Transform):
92+
"""
93+
Transforms a vector of values into a vector of ordered values.
94+
95+
Parameters
96+
----------
97+
positive: If True, all values are positive. This has better geometry than just chaining with a log transform.
98+
ascending: If True, the values are in ascending order (default). If False, the values are in descending order.
99+
"""
100+
92101
name = "ordered"
93102

94-
def __init__(self, ndim_supp=None):
103+
def __init__(self, ndim_supp=None, positive=False, ascending=True):
95104
if ndim_supp is not None:
96105
warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
106+
self.positive = positive
107+
self.ascending = ascending
97108

98109
def backward(self, value, *inputs):
99-
x = pt.zeros(value.shape)
100-
x = pt.set_subtensor(x[..., 0], value[..., 0])
101-
x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
102-
return pt.cumsum(x, axis=-1)
110+
if self.positive:
111+
x = pt.exp(value)
112+
else: # Everything except the first element is positive
113+
x = pt.zeros(value.shape)
114+
x = pt.set_subtensor(x[..., 0], value[..., 0])
115+
x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
116+
x = pt.cumsum(x, axis=-1)
117+
if not self.ascending:
118+
x = x[..., ::-1]
119+
return x
103120

104121
def forward(self, value, *inputs):
122+
if not self.ascending:
123+
value = value[..., ::-1]
105124
y = pt.zeros(value.shape)
106-
y = pt.set_subtensor(y[..., 0], value[..., 0])
125+
y = pt.set_subtensor(y[..., 0], pt.log(value[..., 0]) if self.positive else value[..., 0])
107126
y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
108127
return y
109128

110129
def log_jac_det(self, value, *inputs):
111-
return pt.sum(value[..., 1:], axis=-1)
130+
if self.positive:
131+
return pt.sum(value, axis=-1)
132+
else:
133+
return pt.sum(value[..., 1:], axis=-1)
112134

113135

114136
class SumTo1(Transform):

tests/distributions/test_transform.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,18 @@ def test_ordered():
281281
vals = get_values(tr.ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3)))
282282
assert_array_equal(np.diff(vals) >= 0, True)
283283

284+
# Check that positive=True creates positive and still ordered values
285+
vals = get_values(tr.Ordered(positive=True), Vector(R, 3), pt.vector, floatX(np.zeros(3)))
286+
assert_array_equal(vals > 0, True)
287+
assert_array_equal(np.diff(vals) >= 0, True)
288+
289+
# Check that positive=True and ascending=False creates descending values
290+
vals = get_values(
291+
tr.Ordered(positive=True, ascending=False), Vector(R, 3), pt.vector, floatX(np.zeros(3))
292+
)
293+
assert_array_equal(vals > 0, True)
294+
assert_array_equal(np.diff(vals) <= 0, True)
295+
284296

285297
def test_chain_values():
286298
chain_tranf = tr.Chain([tr.logodds, tr.ordered])

0 commit comments

Comments
 (0)