Skip to content

Commit 5dcd101

Browse files
committed
Fix logprob inference for scans with carried deterministic states
1 parent a542581 commit 5dcd101

File tree

2 files changed

+97
-23
lines changed

2 files changed

+97
-23
lines changed

pymc/logprob/scan.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,13 @@
4242
import pytensor.tensor as pt
4343

4444
from pytensor.graph.basic import Variable
45-
from pytensor.graph.fg import FunctionGraph
4645
from pytensor.graph.op import compute_test_value
4746
from pytensor.graph.rewriting.basic import node_rewriter
4847
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
4948
from pytensor.scan.op import Scan
5049
from pytensor.scan.rewriting import scan_eqopt1, scan_eqopt2
5150
from pytensor.scan.utils import ScanArgs
5251
from pytensor.tensor.random.type import RandomType
53-
from pytensor.tensor.rewriting.shape import ShapeFeature
5452
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
5553
from pytensor.tensor.var import TensorVariable
5654
from pytensor.updates import OrderedUpdates
@@ -63,11 +61,12 @@
6361
)
6462
from pymc.logprob.joint_logprob import factorized_joint_logprob
6563
from pymc.logprob.rewriting import (
66-
PreserveRVMappings,
64+
construct_ir_fgraph,
6765
inc_subtensor_ops,
6866
logprob_rewrites_db,
6967
measurable_ir_rewrites_db,
7068
)
69+
from pymc.pytensorf import replace_rvs_by_values
7170

7271

7372
class MeasurableScan(Scan):
@@ -249,9 +248,27 @@ def remove(x, i):
249248
new_inner_out_nit_sot = tuple(output_scan_args.inner_out_nit_sot) + tuple(
250249
inner_out_fn(remapped_io_to_ii)
251250
)
252-
253251
output_scan_args.inner_out_nit_sot = list(new_inner_out_nit_sot)
254252

253+
# Finally, we need to replace any lingering references to the new
254+
# internal variables that could be in the recurrent states needed
255+
# to compute the new nit_sots
256+
traced_outs = (
257+
output_scan_args.inner_out_mit_sot
258+
+ output_scan_args.inner_out_sit_sot
259+
+ output_scan_args.inner_out_nit_sot
260+
)
261+
traced_outs = replace_rvs_by_values(traced_outs, rvs_to_values=remapped_io_to_ii)
262+
# Update output mappings
263+
n_mit_sot = len(output_scan_args.inner_out_mit_sot)
264+
output_scan_args.inner_out_mit_sot = traced_outs[:n_mit_sot]
265+
offset = n_mit_sot
266+
n_sit_sot = len(output_scan_args.inner_out_sit_sot)
267+
output_scan_args.inner_out_sit_sot = traced_outs[offset : offset + n_sit_sot]
268+
offset += n_sit_sot
269+
n_nit_sot = len(output_scan_args.inner_out_nit_sot)
270+
output_scan_args.inner_out_nit_sot = traced_outs[offset : offset + n_nit_sot]
271+
255272
return output_scan_args
256273

257274

@@ -331,7 +348,10 @@ def create_inner_out_logp(value_map: Dict[TensorVariable, TensorVariable]) -> Te
331348
for key, value in updates.items():
332349
key.default_update = value
333350

334-
return logp_scan_out
351+
# Return only the logp outputs, not any potentially carried states
352+
logp_outputs = logp_scan_out[-len(values) :]
353+
354+
return logp_outputs
335355

336356

337357
@node_rewriter([Scan])
@@ -504,19 +524,9 @@ def add_opts_to_inner_graphs(fgraph, node):
504524
if getattr(node.op.mode, "had_logprob_rewrites", False):
505525
return None
506526

507-
inner_fgraph = FunctionGraph(
508-
node.op.inner_inputs,
509-
node.op.inner_outputs,
510-
clone=True,
511-
copy_inputs=False,
512-
copy_orphans=False,
513-
features=[
514-
ShapeFeature(),
515-
PreserveRVMappings({}),
516-
],
517-
)
518-
519-
logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])).rewrite(inner_fgraph)
527+
inner_rv_values = {out: out.type() for out in node.op.inner_outputs}
528+
ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"]))
529+
inner_fgraph, rv_values, _ = construct_ir_fgraph(inner_rv_values, ir_rewriter=ir_rewriter)
520530

521531
new_outputs = list(inner_fgraph.outputs)
522532

@@ -531,11 +541,23 @@ def add_opts_to_inner_graphs(fgraph, node):
531541

532542

533543
@_get_measurable_outputs.register(MeasurableScan)
534-
def _get_measurable_outputs_MeasurableScan(op, node):
535-
# TODO: This should probably use `get_random_outer_outputs`
536-
# scan_args = ScanArgs.from_node(node)
537-
# rv_outer_outs = get_random_outer_outputs(scan_args)
538-
return [o for o in node.outputs if not isinstance(o.type, RandomType)]
544+
def _get_measurable_outputs_MeasurableScan(op: Scan, node):
545+
"""Collect measurable outputs for Measurable Scans"""
546+
inner_out_from_outer_out_map = op.get_oinp_iinp_iout_oout_mappings()["inner_out_from_outer_out"]
547+
inner_outs = op.inner_outputs
548+
549+
# Measurable scan outputs are those whose inner scan output counterparts are also measurable
550+
measurable_outputs = []
551+
for out_idx, out in enumerate(node.outputs):
552+
[inner_out_idx] = inner_out_from_outer_out_map[out_idx]
553+
inner_out = inner_outs[inner_out_idx]
554+
inner_out_node = inner_out.owner
555+
if isinstance(
556+
inner_out_node.op, MeasurableVariable
557+
) and inner_out in get_measurable_outputs(inner_out_node.op, inner_out_node):
558+
measurable_outputs.append(out)
559+
560+
return measurable_outputs
539561

540562

541563
measurable_ir_rewrites_db.register(

tests/logprob/test_scan.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,3 +502,55 @@ def test_scan_over_seqs():
502502
ys_logp.eval({xs_vv: xs_test, ys_vv: ys_test}),
503503
stats.norm.logpdf(ys_test, xs_test),
504504
)
505+
506+
507+
def test_scan_carried_deterministic_state():
508+
"""Test logp of scans with carried states downstream of measured variables.
509+
510+
A moving average model with 2 lags is used for testing.
511+
"""
512+
rng = np.random.default_rng(490)
513+
steps = 99
514+
515+
rho = pt.vector("rho", shape=(2,))
516+
sigma = pt.scalar("sigma")
517+
518+
def ma2_step(eps_tm2, eps_tm1, rho, sigma):
519+
mu = eps_tm1 * rho[0] + eps_tm2 * rho[1]
520+
y = pt.random.normal(mu, sigma)
521+
eps = y - mu
522+
update = {y.owner.inputs[0]: y.owner.outputs[0]}
523+
return (eps, y), update
524+
525+
[_, ma2], ma2_updates = pytensor.scan(
526+
fn=ma2_step,
527+
outputs_info=[{"initial": pt.arange(2, dtype="float64"), "taps": range(-2, 0)}, None],
528+
non_sequences=[rho, sigma],
529+
n_steps=steps,
530+
strict=True,
531+
name="ma2",
532+
)
533+
534+
def ref_logp(values, rho, sigma):
535+
epsilon_tm2 = 0
536+
epsilon_tm1 = 1
537+
step_logps = np.zeros_like(values)
538+
for t, value in enumerate(values):
539+
mu = epsilon_tm1 * rho[0] + epsilon_tm2 * rho[1]
540+
step_logps[t] = stats.norm.logpdf(value, mu, sigma)
541+
epsilon_tm2 = epsilon_tm1
542+
epsilon_tm1 = value - mu
543+
return step_logps
544+
545+
ma2_vv = ma2.clone()
546+
logp_expr = logp(ma2, ma2_vv)
547+
assert_no_rvs(logp_expr)
548+
549+
ma2_test = rng.normal(size=(steps,))
550+
rho_test = np.array([0.3, 0.7])
551+
sigma_test = 0.9
552+
553+
np.testing.assert_array_almost_equal(
554+
logp_expr.eval({ma2_vv: ma2_test, rho: rho_test, sigma: sigma_test}),
555+
ref_logp(ma2_test, rho_test, sigma_test),
556+
)

0 commit comments

Comments
 (0)