Skip to content

Commit f0390d4

Browse files
committed
Automatically retrieve updates from OpFromGraph nodes
1 parent 937e5fd commit f0390d4

File tree

4 files changed

+83
-41
lines changed

4 files changed

+83
-41
lines changed

pymc/distributions/distribution.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from pymc.printing import str_for_dist
6262
from pymc.pytensorf import (
6363
collect_default_updates,
64+
collect_default_updates_inner_fgraph,
6465
constant_fold,
6566
convert_observed_data,
6667
floatX,
@@ -300,14 +301,14 @@ def __init__(
300301
kwargs.setdefault("inline", True)
301302
super().__init__(*args, **kwargs)
302303

303-
def update(self, node: Node):
304+
def update(self, node: Node) -> dict[Variable, Variable]:
304305
"""Symbolic update expression for input random state variables
305306
306307
Returns a dictionary with the symbolic expressions required for correct updating
307308
of random state input variables repeated function evaluations. This is used by
308309
`pytensorf.compile_pymc`.
309310
"""
310-
return {}
311+
return collect_default_updates_inner_fgraph(node)
311312

312313
def batch_ndim(self, node: Node) -> int:
313314
"""Number of dimensions of the distribution's batch shape."""
@@ -705,20 +706,6 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
705706

706707
_print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}")
707708

708-
def update(self, node: Node):
709-
op = node.op
710-
inner_updates = collect_default_updates(
711-
inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
712-
)
713-
714-
# Map inner updates to outer inputs/outputs
715-
updates = {}
716-
for rng, update in inner_updates.items():
717-
inp_idx = op.inner_inputs.index(rng)
718-
out_idx = op.inner_outputs.index(update)
719-
updates[node.inputs[inp_idx]] = node.outputs[out_idx]
720-
return updates
721-
722709

723710
@_support_point.register(CustomSymbolicDistRV)
724711
def dist_support_point(op, rv, *args):

pymc/pytensorf.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323

2424
from pytensor import scalar
2525
from pytensor.compile import Function, Mode, get_mode
26+
from pytensor.compile.builders import OpFromGraph
2627
from pytensor.gradient import grad
2728
from pytensor.graph import Type, rewrite_graph
2829
from pytensor.graph.basic import (
2930
Apply,
3031
Constant,
32+
Node,
3133
Variable,
3234
clone_get_equiv,
3335
graph_inputs,
@@ -781,6 +783,23 @@ def reseed_rngs(
781783
rng.set_value(new_rng, borrow=True)
782784

783785

786+
def collect_default_updates_inner_fgraph(node: Node) -> dict[Variable, Variable]:
787+
"""Collect default updates from node with inner fgraph."""
788+
op = node.op
789+
inner_updates = collect_default_updates(
790+
inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
791+
)
792+
793+
# Map inner updates to outer inputs/outputs
794+
updates = {}
795+
for rng, update in inner_updates.items():
796+
inp_idx = op.inner_inputs.index(rng)
797+
out_idx = op.inner_outputs.index(update)
798+
updates[node.inputs[inp_idx]] = node.outputs[out_idx]
799+
800+
return updates
801+
802+
784803
def collect_default_updates(
785804
outputs: Sequence[Variable],
786805
*,
@@ -874,9 +893,16 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
874893
f"No update found for at least one RNG used in Scan Op {client.op}.\n"
875894
"You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
876895
)
896+
elif isinstance(client.op, OpFromGraph):
897+
try:
898+
next_rng = collect_default_updates_inner_fgraph(client)[rng]
899+
except (ValueError, KeyError):
900+
raise ValueError(
901+
f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n"
902+
"You can use `pytensorf.collect_default_updates` and include those updates as outputs."
903+
)
877904
else:
878-
# We don't know how this RNG should be updated (e.g., OpFromGraph).
879-
# The user should provide an update manually
905+
# We don't know how this RNG should be updated. The user should provide an update manually
880906
return None
881907

882908
# Recurse until we find final update for RNG

tests/distributions/test_distribution.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from pymc.exceptions import BlockModelAccessError
5353
from pymc.logprob.basic import conditional_logp, logcdf, logp
5454
from pymc.model import Deterministic, Model
55-
from pymc.pytensorf import collect_default_updates
55+
from pymc.pytensorf import collect_default_updates, compile_pymc
5656
from pymc.sampling import draw, sample
5757
from pymc.testing import (
5858
BaseTestDistributionRandom,
@@ -791,6 +791,41 @@ class TestInlinedSymbolicRV(SymbolicRandomVariable):
791791
x_inline = TestInlinedSymbolicRV([], [Flat.dist()], ndim_supp=0)()
792792
assert np.isclose(logp(x_inline, 0).eval(), 0)
793793

794+
def test_default_update(self):
795+
"""Test SymbolicRandomVariable Op default to updates from inner graph."""
796+
797+
class SymbolicRVDefaultUpdates(SymbolicRandomVariable):
798+
pass
799+
800+
class SymbolicRVCustomUpdates(SymbolicRandomVariable):
801+
def update(self, node):
802+
return {}
803+
804+
rng = pytensor.shared(np.random.default_rng())
805+
dummy_rng = rng.type()
806+
dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs
807+
808+
# Check that default updates work
809+
next_rng, x = SymbolicRVDefaultUpdates(
810+
inputs=[dummy_rng],
811+
outputs=[dummy_next_rng, dummy_x],
812+
ndim_supp=0,
813+
)(rng)
814+
fn = compile_pymc(inputs=[], outputs=x, random_seed=431)
815+
assert fn() != fn()
816+
817+
# Check that custom updates are respected, by using one that's broken
818+
next_rng, x = SymbolicRVCustomUpdates(
819+
inputs=[dummy_rng],
820+
outputs=[dummy_next_rng, dummy_x],
821+
ndim_supp=0,
822+
)(rng)
823+
with pytest.raises(
824+
ValueError,
825+
match="No update found for at least one RNG used in SymbolicRandomVariable Op SymbolicRVCustomUpdates",
826+
):
827+
compile_pymc(inputs=[], outputs=x, random_seed=431)
828+
794829

795830
def test_tag_future_warning_dist():
796831
# Test no unexpected warnings

tests/test_pytensorf.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -408,28 +408,6 @@ def test_compile_pymc_updates_inputs(self):
408408
# Each RV adds a shared output for its rng
409409
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph
410410

411-
def test_compile_pymc_symbolic_rv_update(self):
412-
"""Test that SymbolicRandomVariable Op update methods are used by compile_pymc"""
413-
414-
class NonSymbolicRV(OpFromGraph):
415-
def update(self, node):
416-
return {node.inputs[0]: node.outputs[0]}
417-
418-
rng = pytensor.shared(np.random.default_rng())
419-
dummy_rng = rng.type()
420-
dummy_next_rng, dummy_x = NonSymbolicRV(
421-
[dummy_rng], pt.random.normal(rng=dummy_rng).owner.outputs
422-
)(rng)
423-
424-
# Check that there are no updates at first
425-
fn = compile_pymc(inputs=[], outputs=dummy_x)
426-
assert fn() == fn()
427-
428-
# And they are enabled once the Op is registered as a SymbolicRV
429-
SymbolicRandomVariable.register(NonSymbolicRV)
430-
fn = compile_pymc(inputs=[], outputs=dummy_x, random_seed=431)
431-
assert fn() != fn()
432-
433411
def test_compile_pymc_symbolic_rv_missing_update(self):
434412
"""Test that error is raised if SymbolicRandomVariable Op does not
435413
provide rule for updating RNG"""
@@ -588,6 +566,22 @@ def step_wo_update(x, rng):
588566
fn = compile_pymc([], ys, random_seed=1)
589567
assert not (set(fn()) & set(fn()))
590568

569+
def test_op_from_graph_updates(self):
570+
rng = pytensor.shared(np.random.default_rng())
571+
next_rng_, x_ = pt.random.normal(size=(10,), rng=rng).owner.outputs
572+
573+
x = OpFromGraph([], [x_])()
574+
with pytest.raises(
575+
ValueError,
576+
match="No update found for at least one RNG used in OpFromGraph Op",
577+
):
578+
collect_default_updates([x])
579+
580+
next_rng, x = OpFromGraph([], [next_rng_, x_])()
581+
assert collect_default_updates([x]) == {rng: next_rng}
582+
fn = compile_pymc([], x, random_seed=1)
583+
assert not (set(fn()) & set(fn()))
584+
591585

592586
def test_replace_rng_nodes():
593587
rng = pytensor.shared(np.random.default_rng())

0 commit comments

Comments
 (0)