diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index c53dde90f..6924c1f58 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -1062,10 +1062,8 @@ def vectorize_over_posterior( if rv in all_rvs ]: rv_ancestors = ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs]) - if ( - rv not in needed_rvs - and not ({*outputs, *independent_rvs} & set(rv_ancestors)) - and {var for var in rv_ancestors if var in all_rvs} <= {rv, *needed_rvs} + if rv not in needed_rvs and not ( + {*outputs, *needed_rvs, *independent_rvs} & set(rv_ancestors) ): independent_rvs.append(rv) for rv in independent_rvs: diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 3dd30e14f..bf7b2581b 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -1958,3 +1958,33 @@ def test_vectorize_over_posterior_matches_sample(): atol=0.6 / np.sqrt(10000), ) assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1) + + +def test_vectorize_over_posterior_with_intermediate_rvs(): + with pm.Model() as model: + a = pm.Normal("a") + b = pm.Normal.dist(a) + c = b + 1 + d = pm.Normal.dist(c) + idata = pm.sample_prior_predictive(100, var_names=["a"]) + idata.add_groups({"posterior": idata.prior}) + _, _, vectorized_no_intermediate = vectorize_over_posterior( + outputs=[b, c, d], + posterior=idata.posterior, + input_rvs=[a], + allow_rvs_in_graph=True, + ) + [vectorized_intermediate_rvs] = vectorize_over_posterior( + outputs=[d], + posterior=idata.posterior, + input_rvs=[a], + allow_rvs_in_graph=True, + ) + assert vectorized_no_intermediate.type.shape == (1, 100) + assert vectorized_no_intermediate.type.shape == vectorized_intermediate_rvs.type.shape + a_ancestor1 = get_var_by_name([vectorized_no_intermediate], "a")[0] + a_ancestor2 = get_var_by_name([vectorized_intermediate_rvs], "a")[0] + assert isinstance(a_ancestor1, TensorConstant) + assert np.array_equiv(a_ancestor1.eval(), idata.posterior.a.data) + assert isinstance(a_ancestor2, TensorConstant) + assert np.array_equiv(a_ancestor2.eval(), idata.posterior.a.data)