Skip to content

Commit a542581

Browse files
committed
Add test for scans over sequences
1 parent fafc020 commit a542581

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/logprob/test_scan.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,27 @@ def test_scan_non_pure_rv_output():
478478
grw_logp.eval({grw_vv: grw_vv_test}),
479479
stats.norm.logpdf(np.ones(10)),
480480
)
481+
482+
483+
def test_scan_over_seqs():
484+
"""Test that logprob inference for scans based on sequences (mapping)."""
485+
rng = np.random.default_rng(543)
486+
n_steps = 10
487+
488+
xs = pt.random.normal(size=(n_steps,), name="xs")
489+
ys, _ = pytensor.scan(
490+
fn=lambda x: pt.random.normal(x), sequences=[xs], outputs_info=[None], name="ys"
491+
)
492+
493+
xs_vv = ys.clone()
494+
ys_vv = ys.clone()
495+
ys_logp = factorized_joint_logprob({xs: xs_vv, ys: ys_vv})[ys_vv]
496+
497+
assert_no_rvs(ys_logp)
498+
499+
xs_test = rng.normal(size=(10,))
500+
ys_test = rng.normal(size=(10,))
501+
np.testing.assert_array_almost_equal(
502+
ys_logp.eval({xs_vv: xs_test, ys_vv: ys_test}),
503+
stats.norm.logpdf(ys_test, xs_test),
504+
)

0 commit comments

Comments
 (0)