Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit d19ef6e

Browse files
committed
Create value variables when building the logprob graphs
1 parent f43e868 commit d19ef6e

File tree

11 files changed

+201
-314
lines changed

11 files changed

+201
-314
lines changed

aeppl/joint_logprob.py

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22
from collections import deque
3-
from typing import Dict, Optional, Union
3+
from typing import Dict, List, Optional, Tuple, Union
44

55
import aesara.tensor as at
66
from aesara import config
@@ -16,26 +16,27 @@
1616

1717

1818
def conditional_logprob(
19-
rv_values: Dict[TensorVariable, TensorVariable],
19+
*random_variables: TensorVariable,
20+
realized: Dict[TensorVariable, TensorVariable] = {},
2021
warn_missing_rvs: bool = True,
2122
ir_rewriter: Optional[GraphRewriter] = None,
2223
extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None,
2324
**kwargs,
24-
) -> Dict[TensorVariable, TensorVariable]:
25-
r"""Create a map between variables and their conditional log-probabilities.
25+
) -> Tuple[Dict[TensorVariable, TensorVariable], List[TensorVariable]]:
26+
r"""Create a map between random variables and their conditional log-probabilities.
2627
27-
The `rvs` list implicitly defines a joint probability, that factorizes
28-
according to the graphical model represented by the Aesara model the
29-
`RandomVariable`s belong to.
28+
The list of measurable variables implicitly defines a joint probability that
29+
factorizes according to the graphical model implemented by the Aesara model
30+
these variables belong to.
3031
3132
For example, consider the following
3233
3334
.. code-block:: python
3435
3536
import aesara.tensor as at
3637
37-
sigma2_rv = at.random.invgamma(0.5, 0.5)
38-
Y_rv = at.random.normal(0, at.sqrt(sigma2_rv))
38+
sigma2_rv = at.random.invgamma(0.5, 0.5, name="sigma2")
39+
Y_rv = at.random.normal(0, at.sqrt(sigma2_rv), name="Y")
3940
4041
This graph for ``Y_rv`` is equivalent to the following hierarchical model:
4142
@@ -51,9 +52,9 @@ def conditional_logprob(
5152
:math:`Y`'s respective conditional log-probabilities, :math:`\log P(\sigma^2 = s)`
5253
and :math:`\log p(Y = y | \sigma^2 = s)`.
5354
54-
`conditional_logprob` generates the value variables that correspond to the
55-
measurable variables for which it produces a conditional log-probability
56-
graph and returns them along with the graphs:
55+
To build the log-probability graphs, `conditional_logprob` must generate
56+
value variables associated with each input variable. They are returned along
57+
with the graphs:
5758
5859
.. code-block:: python
5960
@@ -62,16 +63,18 @@ def conditional_logprob(
6263
sigma2_rv = at.random.invgamma(0.5, 0.5)
6364
Y_rv = at.random.normal(0, at.sqrt(sigma2_rv))
6465
65-
logprobs = conditional_logprob(Y_rv, sigma2_rv)
66+
logprobs, value_variables = conditional_logprob(Y_rv, sigma2_rv)
6667
# print(logprobs.keys())
67-
# [sigma2_vv, Y_vv]
68+
# [Y, sigma2]
69+
# print(value_variables)
70+
# [Y_vv, sigma2_vv]
6871
6972
7073
Parameters
7174
==========
72-
rv_values
73-
A ``dict`` that maps measurable variables (e.g. `RandomVariable`s) to
74-
symbolic `Variable`\s that represent their values.
75+
random_variables
76+
A ``list`` of random variables for which we need to return a
77+
conditional log-probability graph.
7578
warn_missing_rvs
7679
When ``True``, issue a warning when a `RandomVariable` is found in
7780
the graph and doesn't have a corresponding value variable specified in
@@ -84,11 +87,27 @@ def conditional_logprob(
8487
8588
Returns
8689
=======
87-
A ``dict`` that maps each value variable to the log-probability factor derived
88-
from the respective `RandomVariable`.
90+
A ``dict`` that maps each random variable to the derived log-probability
91+
factor, and a list of the created valued variables in the same order as the
92+
order in which their corresponding random variables were passed as
93+
arguments.
8994
9095
"""
96+
97+
# Create value variables by cloning the input measurable variables
98+
original_rv_values = {}
99+
for rv in random_variables:
100+
vv = rv.clone()
101+
if rv.name:
102+
vv.name = f"{rv.name}_vv"
103+
original_rv_values[rv] = vv
104+
105+
# Value variables are not cloned when constructing the conditional log-proprobability
106+
# graphs. We can thus use them to recover the original random variables to index the
107+
# maps to the logprob graphs and value variables before returning them.
108+
rv_values = {**original_rv_values, **realized}
91109
vv_to_original_rvs = {vv: rv for rv, vv in rv_values.items()}
110+
92111
fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
93112

94113
if extra_rewrites is not None:
@@ -194,30 +213,45 @@ def conditional_logprob(
194213
f"The logprob terms of the following random variables could not be derived: {missing_value_terms}"
195214
)
196215

197-
return logprob_vars
216+
return logprob_vars, list(original_rv_values.values())
198217

199218

200219
def joint_logprob(
201-
rv_values: Dict[TensorVariable, TensorVariable], *args, **kwargs
202-
) -> Optional[TensorVariable]:
220+
*random_variables: List[TensorVariable],
221+
realized: Dict[TensorVariable, TensorVariable] = {},
222+
**kwargs,
223+
) -> Optional[Tuple[TensorVariable, List[TensorVariable]]]:
203224
"""Create a graph representing the joint log-probability/measure of a graph.
204225
205226
This function calls `factorized_joint_logprob` and returns the combined
206227
log-probability factors as a single graph.
207228
208229
Parameters
209-
----------
210-
sum: bool
211-
If ``True`` each factor is collapsed to a scalar via ``sum`` before
212-
being joined with the remaining factors. This may be necessary to
213-
avoid incorrect broadcasting among independent factors.
230+
==========
231+
random_variables
232+
A ``list`` of random variables for which we need to return a
233+
conditional log-probability graph.
234+
realized
235+
A ``dict`` that maps random variables to their realized value.
236+
237+
Returns
238+
=======
239+
A ``TensorVariable`` that represents the joint log-probability of the graph
240+
implicitly defined by the random variables passed as arguments, and a list
241+
of the created valued variables in the same order as the order in which
242+
their corresponding random variables were passed as arguments.
214243
215244
"""
216-
logprob = conditional_logprob(rv_values, *args, **kwargs)
245+
logprob, value_variables = conditional_logprob(
246+
*random_variables, realized=realized, **kwargs
247+
)
217248
if not logprob:
218249
return None
219250
elif len(logprob) == 1:
220-
logprob = tuple(logprob.values())[0]
221-
return at.sum(logprob)
251+
cond_logprob = tuple(logprob.values())[0]
252+
return at.sum(cond_logprob), value_variables
222253
else:
223-
return at.sum([at.sum(factor) for factor in logprob.values()])
254+
joint_logprob: TensorVariable = at.sum(
255+
[at.sum(factor) for factor in logprob.values()]
256+
)
257+
return joint_logprob, value_variables

aeppl/scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def create_inner_out_logp(
268268
value_map: Dict[TensorVariable, TensorVariable]
269269
) -> TensorVariable:
270270
"""Create a log-likelihood inner-output for a `Scan`."""
271-
logp_parts = conditional_logprob(value_map, warn_missing_rvs=False)
271+
logp_parts, _ = conditional_logprob(realized=value_map, warn_missing_rvs=False)
272272
return logp_parts.values()
273273

274274
logp_scan_args = convert_outer_out_to_in(

tests/test_censoring.py

Lines changed: 17 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,10 @@ def test_continuous_rv_clip():
1515
x_rv = at.random.normal(0.5, 1)
1616
cens_x_rv = at.clip(x_rv, -2, 2)
1717

18-
cens_x_vv = cens_x_rv.clone()
19-
cens_x_vv.tag.test_value = 0
20-
21-
logp = joint_logprob({cens_x_rv: cens_x_vv})
18+
logp, vv = joint_logprob(cens_x_rv)
2219
assert_no_rvs(logp)
2320

24-
logp_fn = aesara.function([cens_x_vv], logp)
21+
logp_fn = aesara.function(vv, logp)
2522
ref_scipy = st.norm(0.5, 1)
2623

2724
assert logp_fn(-3) == -np.inf
@@ -36,12 +33,10 @@ def test_discrete_rv_clip():
3633
x_rv = at.random.poisson(2)
3734
cens_x_rv = at.clip(x_rv, 1, 4)
3835

39-
cens_x_vv = cens_x_rv.clone()
40-
41-
logp = joint_logprob({cens_x_rv: cens_x_vv})
36+
logp, vv = joint_logprob(cens_x_rv)
4237
assert_no_rvs(logp)
4338

44-
logp_fn = aesara.function([cens_x_vv], logp)
39+
logp_fn = aesara.function(vv, logp)
4540
ref_scipy = st.poisson(2)
4641

4742
assert logp_fn(0) == -np.inf
@@ -57,11 +52,8 @@ def test_one_sided_clip():
5752
lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
5853
ub_cens_x_rv = at.clip(x_rv, x_rv, 1)
5954

60-
lb_cens_x_vv = lb_cens_x_rv.clone()
61-
ub_cens_x_vv = ub_cens_x_rv.clone()
62-
63-
lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x_vv})
64-
ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x_vv})
55+
lb_logp, (lb_cens_x_vv,) = joint_logprob(lb_cens_x_rv)
56+
ub_logp, (ub_cens_x_vv,) = joint_logprob(ub_cens_x_rv)
6557
assert_no_rvs(lb_logp)
6658
assert_no_rvs(ub_logp)
6759

@@ -78,9 +70,8 @@ def test_useless_clip():
7870
x_rv = at.random.normal(0.5, 1, size=3)
7971
cens_x_rv = at.clip(x_rv, x_rv, x_rv)
8072

81-
cens_x_vv = cens_x_rv.clone()
82-
83-
logp = conditional_logprob({cens_x_rv: cens_x_vv})[cens_x_rv]
73+
logps, (cens_x_vv,) = conditional_logprob(cens_x_rv)
74+
logp = logps[cens_x_rv]
8475
assert_no_rvs(logp)
8576

8677
logp_fn = aesara.function([cens_x_vv], logp)
@@ -94,9 +85,7 @@ def test_random_clip():
9485
x_rv = at.random.normal(0, 2)
9586
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
9687

97-
lb_vv = lb_rv.clone()
98-
cens_x_vv = cens_x_rv.clone()
99-
logps = conditional_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
88+
logps, (cens_x_vv, lb_vv) = conditional_logprob(cens_x_rv, lb_rv)
10089
logp = at.add(*logps.values())
10190
assert_no_rvs(logp)
10291

@@ -111,10 +100,7 @@ def test_broadcasted_clip_constant():
111100
x_rv = at.random.normal(0, 2)
112101
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
113102

114-
lb_vv = lb_rv.clone()
115-
cens_x_vv = cens_x_rv.clone()
116-
117-
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
103+
logp, _ = joint_logprob(cens_x_rv, lb_rv)
118104
assert_no_rvs(logp)
119105

120106

@@ -123,10 +109,7 @@ def test_broadcasted_clip_random():
123109
x_rv = at.random.normal(0, 2, size=2)
124110
cens_x_rv = at.clip(x_rv, lb_rv, 1)
125111

126-
lb_vv = lb_rv.clone()
127-
cens_x_vv = cens_x_rv.clone()
128-
129-
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
112+
logp, _ = joint_logprob(cens_x_rv, lb_rv)
130113
assert_no_rvs(logp)
131114

132115

@@ -136,10 +119,8 @@ def test_fail_base_and_clip_have_values():
136119
cens_x_rv = at.clip(x_rv, x_rv, 1)
137120
cens_x_rv.name = "cens_x"
138121

139-
x_vv = x_rv.clone()
140-
cens_x_vv = cens_x_rv.clone()
141122
with pytest.raises(RuntimeError, match="could not be derived: {cens_x}"):
142-
conditional_logprob({cens_x_rv: cens_x_vv, x_rv: x_vv})
123+
conditional_logprob(cens_x_rv, x_rv)
143124

144125

145126
def test_fail_multiple_clip_single_base():
@@ -150,20 +131,16 @@ def test_fail_multiple_clip_single_base():
150131
cens_rv2 = at.clip(base_rv, -1, 1)
151132
cens_rv2.name = "cens2"
152133

153-
cens_vv1 = cens_rv1.clone()
154-
cens_vv2 = cens_rv2.clone()
155134
with pytest.raises(RuntimeError, match="could not be derived: {cens2}"):
156-
conditional_logprob({cens_rv1: cens_vv1, cens_rv2: cens_vv2})
135+
conditional_logprob(cens_rv1, cens_rv2)
157136

158137

159138
def test_deterministic_clipping():
160139
x_rv = at.random.normal(0, 1)
161140
clip = at.clip(x_rv, 0, 0)
162141
y_rv = at.random.normal(clip, 1)
163142

164-
x_vv = x_rv.clone()
165-
y_vv = y_rv.clone()
166-
logp = joint_logprob({x_rv: x_vv, y_rv: y_vv})
143+
logp, (x_vv, y_vv) = joint_logprob(x_rv, y_rv)
167144
assert_no_rvs(logp)
168145

169146
logp_fn = aesara.function([x_vv, y_vv], logp)
@@ -180,7 +157,7 @@ def test_clip_transform():
180157
cens_x_vv = cens_x_rv.clone()
181158

182159
transform = TransformValuesRewrite({cens_x_vv: LogTransform()})
183-
logp = joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)
160+
logp, _ = joint_logprob(realized={cens_x_rv: cens_x_vv}, extra_rewrites=transform)
184161

185162
cens_x_vv_testval = -1
186163
obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval})
@@ -201,8 +178,8 @@ def test_rounding(rounding_op):
201178
xr = rounding_op(x)
202179
xr.name = "xr"
203180

204-
xr_vv = xr.clone()
205-
logp = conditional_logprob({xr: xr_vv})[xr]
181+
logp, (xr_vv,) = conditional_logprob(xr)
182+
logp = logp[xr]
206183
assert logp is not None
207184

208185
x_sp = st.norm(loc, scale)

0 commit comments

Comments
 (0)