Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit fb6ffc5

Browse files
Explicitly remove AePPL IR from conditional_logprob results
1 parent c1ac578 commit fb6ffc5

File tree

6 files changed

+53
-11
lines changed

6 files changed

+53
-11
lines changed

aeppl/abstract.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class ValuedVariable(Op):
144144
directly. An example is `BroadcastTo` lifting through `RandomVariable`\s.
145145
"""
146146

147+
__props__ = ()
147148
default_output = 0
148149
view_map = {0: [0]}
149150

aeppl/joint_logprob.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from aeppl.abstract import ValuedVariable, get_measurable_outputs
1010
from aeppl.logprob import _logprob
11-
from aeppl.rewriting import construct_ir_fgraph
11+
from aeppl.rewriting import construct_ir_fgraph, ir_cleanup_db
1212

1313
if TYPE_CHECKING:
1414
from aesara.graph.basic import Apply, Variable
@@ -219,11 +219,9 @@ def conditional_logprob(
219219
# for node in io_toposort(graph_inputs([rv_logprobs]), outputs):
220220
# compute_test_value(node)
221221

222-
# Replace `ValuedVariable`s with their values
222+
# Remove unneeded IR elements from the graph
223223
rv_logprobs_fg = FunctionGraph(outputs=tuple(logprob_vars.values()), clone=False)
224-
rv_logprobs_fg.replace_all(
225-
tuple((valued_var, valued_var.owner.inputs[1]) for valued_var in fgraph.outputs)
226-
)
224+
ir_cleanup_db.query("+basic").rewrite(rv_logprobs_fg)
227225

228226
return logprob_vars, value_vars
229227

aeppl/rewriting.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from aesara.graph.basic import Apply, Variable
66
from aesara.graph.features import Feature
77
from aesara.graph.fg import FunctionGraph
8-
from aesara.graph.rewriting.basic import GraphRewriter, node_rewriter
8+
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
99
from aesara.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery, SequenceDB
1010
from aesara.tensor.elemwise import DimShuffle, Elemwise
1111
from aesara.tensor.extra_ops import BroadcastTo
@@ -277,3 +277,18 @@ def construct_ir_fgraph(
277277
fgraph.replace_all(new_to_old, reason="undo-unvalued-measurables")
278278

279279
return fgraph, rv_value_clones, memo
280+
281+
282+
@register_useless
283+
@node_rewriter([ValuedVariable])
284+
def remove_ValuedVariable(fgraph, node):
285+
return [node.inputs[1]]
286+
287+
288+
ir_cleanup_db = SequenceDB()
289+
ir_cleanup_db.name = "ir_cleanup_db"
290+
ir_cleanup_db.register(
291+
"remove-intermediate-ir",
292+
in2out(local_remove_DiracDelta, remove_ValuedVariable),
293+
"basic",
294+
)

aeppl/transforms.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
valued_variable,
2525
)
2626
from aeppl.logprob import _logprob, logprob
27-
from aeppl.rewriting import measurable_ir_rewrites_db
27+
from aeppl.rewriting import ir_cleanup_db, measurable_ir_rewrites_db
2828

2929
if TYPE_CHECKING:
3030
from aesara.graph.rewriting.basic import NodeRewriter
@@ -75,6 +75,7 @@ class TransformedVariable(Op):
7575
7676
"""
7777

78+
__props__ = ()
7879
view_map = {0: [0]}
7980

8081
def make_node(self, tran_value: TensorVariable, value: TensorVariable):
@@ -101,8 +102,12 @@ def grad(self, args, g_outs):
101102
@register_useless
102103
@node_rewriter([TransformedVariable])
103104
def remove_TransformedVariables(fgraph, node):
104-
if isinstance(node.op, TransformedVariable):
105-
return [node.inputs[0]]
105+
return [node.inputs[0]]
106+
107+
108+
ir_cleanup_db.register(
109+
"remove-TransformedVariables", in2out(remove_TransformedVariables), "basic"
110+
)
106111

107112

108113
class RVTransform(abc.ABC):

tests/test_abstract.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import aesara.tensor as at
33
import numpy as np
44
import pytest
5+
from aesara.compile.mode import get_default_mode
56
from aesara.gradient import NullTypeGradError, grad
67
from aesara.tensor.random.basic import NormalRV
78

@@ -113,7 +114,9 @@ def test_valued_variable():
113114
obs_var = valued_variable(rv_var, rv_vv)
114115

115116
rv_val = np.zeros(3)
116-
res = obs_var.eval({rv_var: rv_val, rv_vv: np.ones(3)})
117+
mode = get_default_mode().excluding("remove_ValuedVariable")
118+
obs_var_fn = aesara.function([rv_var, rv_vv], obs_var, mode=mode)
119+
res = obs_var_fn(rv_val, np.ones(3))
117120
assert np.array_equal(res, rv_val)
118121

119122
with pytest.raises(NullTypeGradError):

tests/test_joint_logprob.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pytest
55
import scipy.stats.distributions as sp
6-
from aesara.graph.basic import ancestors, equal_computations
6+
from aesara.graph.basic import ancestors, applys_between, equal_computations
77
from aesara.tensor.subtensor import (
88
AdvancedIncSubtensor,
99
AdvancedIncSubtensor1,
@@ -13,6 +13,7 @@
1313
Subtensor,
1414
)
1515

16+
from aeppl.abstract import ValuedVariable
1617
from aeppl.joint_logprob import conditional_logprob, joint_logprob
1718
from aeppl.logprob import logprob
1819
from aeppl.utils import rvs_to_value_vars, walk_model
@@ -279,3 +280,22 @@ def test_deprecations():
279280

280281
with pytest.warns(DeprecationWarning):
281282
conditional_logprob(realized={X: x}, warn_missing_rvs=True)
283+
284+
285+
def test_no_output_ValuedVariables():
286+
srng = at.random.RandomStream(0)
287+
288+
X_at = at.matrix("X")
289+
tau_rv = srng.halfcauchy(1)
290+
beta_rv = srng.normal(0, tau_rv, size=X_at.shape[-1])
291+
292+
eta = X_at @ beta_rv
293+
p = at.sigmoid(-eta)
294+
Y_rv = srng.bernoulli(p)
295+
296+
logdensity, vvs = joint_logprob(Y_rv, beta_rv, tau_rv)
297+
298+
assert not any(
299+
isinstance(node.op, ValuedVariable)
300+
for node in applys_between(ins=vvs, outs=(logdensity,))
301+
)

0 commit comments

Comments
 (0)