55
55
56
56
import pymc as pm
57
57
58
- from pymc .logprob .basic import factorized_joint_logprob , icdf , joint_logp , logcdf , logp
58
+ from pymc .logprob .basic import (
59
+ conditional_logp ,
60
+ icdf ,
61
+ logcdf ,
62
+ logp ,
63
+ transformed_conditional_logp ,
64
+ )
59
65
from pymc .logprob .transforms import LogTransform
60
66
from pymc .logprob .utils import rvs_to_value_vars , walk_model
61
67
from pymc .pytensorf import replace_rvs_by_values
@@ -68,7 +74,7 @@ def test_factorized_joint_logprob_basic():
68
74
a .name = "a"
69
75
a_value_var = a .clone ()
70
76
71
- a_logp = factorized_joint_logprob ({a : a_value_var })
77
+ a_logp = conditional_logp ({a : a_value_var })
72
78
a_logp_comb = tuple (a_logp .values ())[0 ]
73
79
a_logp_exp = logp (a , a_value_var )
74
80
@@ -81,7 +87,7 @@ def test_factorized_joint_logprob_basic():
81
87
sigma_value_var = sigma .clone ()
82
88
y_value_var = Y .clone ()
83
89
84
- total_ll = factorized_joint_logprob ({Y : y_value_var , sigma : sigma_value_var })
90
+ total_ll = conditional_logp ({Y : y_value_var , sigma : sigma_value_var })
85
91
total_ll_combined = pt .add (* total_ll .values ())
86
92
87
93
# We need to replace the reference to `sigma` in `Y` with its value
@@ -106,7 +112,7 @@ def test_factorized_joint_logprob_basic():
106
112
b_value_var = b .clone ()
107
113
c_value_var = c .clone ()
108
114
109
- b_logp = factorized_joint_logprob ({a : a_value_var , b : b_value_var , c : c_value_var })
115
+ b_logp = conditional_logp ({a : a_value_var , b : b_value_var , c : c_value_var })
110
116
b_logp_combined = pt .sum ([pt .sum (factor ) for factor in b_logp .values ()])
111
117
112
118
# There shouldn't be any `RandomVariable`s in the resulting graph
@@ -125,7 +131,7 @@ def test_factorized_joint_logprob_multi_obs():
125
131
a_val = a .clone ()
126
132
b_val = b .clone ()
127
133
128
- logp_res = factorized_joint_logprob ({a : a_val , b : b_val })
134
+ logp_res = conditional_logp ({a : a_val , b : b_val })
129
135
logp_res_combined = pt .add (* logp_res .values ())
130
136
logp_exp = logp (a , a_val ) + logp (b , b_val )
131
137
@@ -137,8 +143,8 @@ def test_factorized_joint_logprob_multi_obs():
137
143
x_val = x .clone ()
138
144
y_val = y .clone ()
139
145
140
- logp_res = factorized_joint_logprob ({x : x_val , y : y_val })
141
- exp_logp = factorized_joint_logprob ({x : x_val , y : y_val })
146
+ logp_res = conditional_logp ({x : x_val , y : y_val })
147
+ exp_logp = conditional_logp ({x : x_val , y : y_val })
142
148
logp_res_comb = pt .sum ([pt .sum (factor ) for factor in logp_res .values ()])
143
149
exp_logp_comb = pt .sum ([pt .sum (factor ) for factor in exp_logp .values ()])
144
150
@@ -155,7 +161,7 @@ def test_factorized_joint_logprob_diff_dims():
155
161
y_vv = y .clone ()
156
162
y_vv .name = "y"
157
163
158
- logp = factorized_joint_logprob ({x : x_vv , y : y_vv })
164
+ logp = conditional_logp ({x : x_vv , y : y_vv })
159
165
logp_combined = pt .sum ([pt .sum (factor ) for factor in logp .values ()])
160
166
161
167
M_val = np .random .normal (size = (10 , 3 ))
@@ -181,7 +187,7 @@ def test_incsubtensor_original_values_output_dict():
181
187
rv = pt .set_subtensor (base_rv [0 ], 5 )
182
188
vv = rv .clone ()
183
189
184
- logp_dict = factorized_joint_logprob ({rv : vv })
190
+ logp_dict = conditional_logp ({rv : vv })
185
191
assert vv in logp_dict
186
192
187
193
@@ -194,14 +200,14 @@ def test_persist_inputs():
194
200
beta_vv = beta_rv .type ()
195
201
y_vv = Y_rv .clone ()
196
202
197
- logp = factorized_joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv })
203
+ logp = conditional_logp ({beta_rv : beta_vv , Y_rv : y_vv })
198
204
logp_combined = pt .sum ([pt .sum (factor ) for factor in logp .values ()])
199
205
200
206
assert x in ancestors ([logp_combined ])
201
207
202
208
# Make sure we don't clone value variables when they're graphs.
203
209
y_vv_2 = y_vv * 2
204
- logp_2 = factorized_joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv_2 })
210
+ logp_2 = conditional_logp ({beta_rv : beta_vv , Y_rv : y_vv_2 })
205
211
logp_2_combined = pt .sum ([pt .sum (factor ) for factor in logp_2 .values ()])
206
212
207
213
assert y_vv in ancestors ([logp_2_combined ])
@@ -210,7 +216,7 @@ def test_persist_inputs():
210
216
# Even when they are random
211
217
y_vv = pt .random .normal (name = "y_vv2" )
212
218
y_vv_2 = y_vv * 2
213
- logp_2 = factorized_joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv_2 })
219
+ logp_2 = conditional_logp ({beta_rv : beta_vv , Y_rv : y_vv_2 })
214
220
logp_2_combined = pt .sum ([pt .sum (factor ) for factor in logp_2 .values ()])
215
221
216
222
assert y_vv in ancestors ([logp_2_combined ])
@@ -224,11 +230,11 @@ def test_warn_random_found_factorized_joint_logprob():
224
230
y_vv = y_rv .clone ()
225
231
226
232
with pytest .warns (UserWarning , match = "Random variables detected in the logp graph: {x}" ):
227
- factorized_joint_logprob ({y_rv : y_vv })
233
+ conditional_logp ({y_rv : y_vv })
228
234
229
235
with warnings .catch_warnings ():
230
236
warnings .simplefilter ("error" )
231
- factorized_joint_logprob ({y_rv : y_vv }, warn_missing_rvs = False )
237
+ conditional_logp ({y_rv : y_vv }, warn_missing_rvs = False )
232
238
233
239
234
240
def test_multiple_rvs_to_same_value_raises ():
@@ -237,9 +243,9 @@ def test_multiple_rvs_to_same_value_raises():
237
243
x = x_rv1 .type ()
238
244
x .name = "x"
239
245
240
- msg = "More than one logprob factor was assigned to the value var x"
246
+ msg = "More than one logprob term was assigned to the value var x"
241
247
with pytest .raises (ValueError , match = msg ):
242
- factorized_joint_logprob ({x_rv1 : x , x_rv2 : x })
248
+ conditional_logp ({x_rv1 : x , x_rv2 : x })
243
249
244
250
245
251
def test_joint_logp_basic ():
@@ -259,7 +265,7 @@ def test_joint_logp_basic():
259
265
260
266
c_value_var = m .rvs_to_values [c ]
261
267
262
- (b_logp ,) = joint_logp (
268
+ (b_logp ,) = transformed_conditional_logp (
263
269
(b ,),
264
270
rvs_to_values = m .rvs_to_values ,
265
271
rvs_to_transforms = m .rvs_to_transforms ,
@@ -304,7 +310,7 @@ def test_joint_logp_incsubtensor(indices, size):
304
310
a_idx_value_var = a_idx .type ()
305
311
a_idx_value_var .name = "a_idx_value"
306
312
307
- a_idx_logp = joint_logp (
313
+ a_idx_logp = transformed_conditional_logp (
308
314
(a_idx ,),
309
315
rvs_to_values = {a_idx : a_value_var },
310
316
rvs_to_transforms = {},
0 commit comments