Skip to content

Commit 937e5fd

Browse files
committed
Fix join logp for multivariate RVs
1 parent 6f90f83 commit 937e5fd

File tree

2 files changed

+25
-35
lines changed

2 files changed

+25
-35
lines changed

pymc/logprob/tensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,12 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs):
187187
# If the stacked variables depend on each other, we have to replace them by the respective values
188188
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_split_values)
189189

190-
base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim
190+
# Make axis positive and adjust for multivariate logp fewer dimensions to the right
191+
axis = pt.switch(axis >= 0, axis, value.ndim + axis)
192+
axis = pt.minimum(axis, logps[0].ndim - 1)
191193
join_logprob = pt.concatenate(
192194
[pt.atleast_1d(logp) for logp in logps],
193-
axis=axis - base_vars_ndim_supp,
195+
axis=axis,
194196
)
195197

196198
return join_logprob

tests/logprob/test_tensor.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -269,34 +269,23 @@ def test_measurable_join_univariate(size1, size2, axis, concatenate):
269269

270270

271271
@pytest.mark.parametrize(
272-
"size1, supp_size1, size2, supp_size2, axis, concatenate",
272+
"size1, supp_size1, size2, supp_size2, axis, concatenate, logp_axis",
273273
[
274-
(None, 2, None, 2, 0, True),
275-
(None, 2, None, 2, -1, True),
276-
((5,), 2, (3,), 2, 0, True),
277-
((5,), 2, (3,), 2, -2, True),
278-
((2,), 5, (2,), 3, 1, True),
279-
pytest.param(
280-
(2,),
281-
5,
282-
(2,),
283-
5,
284-
0,
285-
False,
286-
marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
287-
),
288-
pytest.param(
289-
(2,),
290-
5,
291-
(2,),
292-
5,
293-
1,
294-
False,
295-
marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
296-
),
274+
(None, 2, None, 2, 0, True, 0),
275+
(None, 2, None, 2, -1, True, 0),
276+
((5,), 2, (3,), 2, 0, True, 0),
277+
((5,), 2, (3,), 2, -2, True, 0),
278+
((2,), 5, (2,), 3, 1, True, 0),
279+
((5, 6), 10, (5, 1), 10, 1, True, 1),
280+
((5, 6), 10, (5, 1), 10, -2, True, 1),
281+
((2,), 5, (2,), 5, 0, False, 0),
282+
((2,), 5, (2,), 5, 1, False, 1),
283+
((5, 6), 10, (5, 6), 10, 2, False, 2),
297284
],
298285
)
299-
def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis, concatenate):
286+
def test_measurable_join_multivariate(
287+
size1, supp_size1, size2, supp_size2, axis, concatenate, logp_axis
288+
):
300289
base1_rv = pt.random.multivariate_normal(
301290
np.zeros(supp_size1), np.eye(supp_size1), size=size1, name="base1"
302291
)
@@ -310,19 +299,18 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
310299
base1_vv = base1_rv.clone()
311300
base2_vv = base2_rv.clone()
312301
y_vv = y_rv.clone()
302+
303+
y_logp = logp(y_rv, y_vv)
304+
assert_no_rvs(y_logp)
305+
313306
base_logps = [
314307
pt.atleast_1d(logp)
315308
for logp in conditional_logp({base1_rv: base1_vv, base2_rv: base2_vv}).values()
316309
]
317-
318310
if concatenate:
319-
axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim)
320-
base_logps = pt.concatenate(base_logps, axis=axis_norm - 1)
311+
expected_logp = pt.concatenate(base_logps, axis=logp_axis)
321312
else:
322-
axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim + 1)
323-
base_logps = pt.stack(base_logps, axis=axis_norm - 1)
324-
y_logp = y_logp = logp(y_rv, y_vv)
325-
assert_no_rvs(y_logp)
313+
expected_logp = pt.stack(base_logps, axis=logp_axis)
326314

327315
base1_testval = base1_rv.eval()
328316
base2_testval = base2_rv.eval()
@@ -331,7 +319,7 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
331319
else:
332320
y_testval = np.stack((base1_testval, base2_testval), axis=axis)
333321
np.testing.assert_allclose(
334-
base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
322+
expected_logp.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
335323
y_logp.eval({y_vv: y_testval}),
336324
)
337325

0 commit comments

Comments
 (0)