
Commit 97f0f79

Make joint_logpt always return lists when sum=False
1 parent 7cc570a commit 97f0f79

5 files changed: +12 −19 lines
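
With this change, calling joint_logpt with sum=False always returns a list of log-probability terms, even when only a single variable is requested, so callers index into the result instead of relying on the old single-variable unwrapping. A minimal sketch of the new contract (assuming a PyMC development version around this commit, where a variable's value variable is reachable as x.tag.value_var, as in the updated tests):

import pymc as pm
from pymc.distributions.logprob import joint_logpt

with pm.Model():
    x = pm.Normal("x")

# sum=False now always yields a list of logp terms, even for a single
# variable, so the single term has to be indexed explicitly.
logp_terms = joint_logpt(x, x.tag.value_var, sum=False)
assert isinstance(logp_terms, list)
print(logp_terms[0].eval({x.tag.value_var: 0.0}))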

pymc/distributions/logprob.py

Lines changed: 0 additions & 6 deletions
@@ -246,12 +246,6 @@ def joint_logpt(
         logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()])
     else:
         logp_var = list(logp_var_dict.values())
-        # TODO: deprecate special behavior when only one variable is requested and
-        # always return a list. This is here for backwards compatibility as logpt
-        # started as a replacement to factor.logpt, but it should now be considered an
-        # internal function reached only via model.logp* methods.
-        if len(logp_var) == 1:
-            logp_var = logp_var[0]

     return logp_var

pymc/model.py

Lines changed: 2 additions & 3 deletions
@@ -749,11 +749,10 @@ def logpt(
                     f"Requested variable {var} not found among the model variables"
                 )

-        rv_logps = []
+        rv_logps: List[TensorVariable] = []
         if rv_values:
             rv_logps = joint_logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
-            if not isinstance(rv_logps, list):
-                rv_logps = [rv_logps]
+            assert isinstance(rv_logps, list)

         # Replace random variables by their value variables in potential terms
         potential_logps = []

pymc/tests/test_distributions.py

Lines changed: 5 additions & 5 deletions
@@ -2783,19 +2783,19 @@ def test_array_bound(self):
         UpperPoisson = Bound("upper", dist, upper=[np.inf, 10], transform=None)
         BoundedPoisson = Bound("bounded", dist, lower=[1, 2], upper=[9, 10], transform=None)

-        first, second = joint_logpt(LowerPoisson, [0, 0], sum=False).eval()
+        first, second = joint_logpt(LowerPoisson, [0, 0], sum=False)[0].eval()
         assert first == -np.inf
         assert second != -np.inf

-        first, second = joint_logpt(UpperPoisson, [11, 11], sum=False).eval()
+        first, second = joint_logpt(UpperPoisson, [11, 11], sum=False)[0].eval()
         assert first != -np.inf
         assert second == -np.inf

-        first, second = joint_logpt(BoundedPoisson, [1, 1], sum=False).eval()
+        first, second = joint_logpt(BoundedPoisson, [1, 1], sum=False)[0].eval()
         assert first != -np.inf
         assert second == -np.inf

-        first, second = joint_logpt(BoundedPoisson, [10, 10], sum=False).eval()
+        first, second = joint_logpt(BoundedPoisson, [10, 10], sum=False)[0].eval()
         assert first == -np.inf
         assert second != -np.inf

@@ -3285,7 +3285,7 @@ def logp(value, mu):
         a_val = np.random.normal(loc=mu_val, scale=1, size=to_tuple(size) + (supp_shape,)).astype(
             aesara.config.floatX
         )
-        log_densityt = joint_logpt(a, a.tag.value_var, sum=False)
+        log_densityt = joint_logpt(a, a.tag.value_var, sum=False)[0]
         assert log_densityt.eval(
             {a.tag.value_var: a_val, mu.tag.value_var: mu_val},
         ).shape == to_tuple(size)

pymc/tests/test_logprob.py

Lines changed: 2 additions & 2 deletions
@@ -60,7 +60,7 @@ def test_joint_logpt_basic():

     b_logp = joint_logpt(b, b_value_var, sum=False)

-    res_ancestors = list(walk_model((b_logp,), walk_past_rvs=True))
+    res_ancestors = list(walk_model(b_logp, walk_past_rvs=True))
     res_rv_ancestors = [
         v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
     ]
@@ -104,7 +104,7 @@ def test_joint_logpt_incsubtensor(indices, size):

     a_idx_logp = joint_logpt(a_idx, {a_idx: a_value_var}, sum=False)

-    logp_vals = a_idx_logp.eval({a_value_var: a_val})
+    logp_vals = a_idx_logp[0].eval({a_value_var: a_val})

     # The indices that were set should all have the same log-likelihood values,
     # because the values they were set to correspond to the unique means along

pymc/tests/test_transforms.py

Lines changed: 3 additions & 3 deletions
@@ -296,7 +296,7 @@ def check_transform_elementwise_logp(self, model):
         x_val_untransf = at.constant(test_array_untransf).type()

         jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
-        assert joint_logpt(x, sum=False).ndim == x.ndim == jacob_det.ndim
+        assert joint_logpt(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim

         v1 = joint_logpt(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf})
         v2 = joint_logpt(x, x_val_untransf, transformed=False).eval(
@@ -319,10 +319,10 @@ def check_vectortransform_elementwise_logp(self, model):
         jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
         # Original distribution is univariate
         if x.owner.op.ndim_supp == 0:
-            assert joint_logpt(x, sum=False).ndim == x.ndim == (jacob_det.ndim + 1)
+            assert joint_logpt(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
         # Original distribution is multivariate
         else:
-            assert joint_logpt(x, sum=False).ndim == (x.ndim - 1) == jacob_det.ndim
+            assert joint_logpt(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim

         a = joint_logpt(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf})
         b = joint_logpt(x, x_val_untransf, transformed=False).eval(
