@@ -15,13 +15,10 @@ def test_continuous_rv_clip():
1515 x_rv = at .random .normal (0.5 , 1 )
1616 cens_x_rv = at .clip (x_rv , - 2 , 2 )
1717
18- cens_x_vv = cens_x_rv .clone ()
19- cens_x_vv .tag .test_value = 0
20-
21- logp = joint_logprob ({cens_x_rv : cens_x_vv })
18+ logp , vv = joint_logprob (cens_x_rv )
2219 assert_no_rvs (logp )
2320
24- logp_fn = aesara .function ([ cens_x_vv ] , logp )
21+ logp_fn = aesara .function (vv , logp )
2522 ref_scipy = st .norm (0.5 , 1 )
2623
2724 assert logp_fn (- 3 ) == - np .inf
@@ -36,12 +33,10 @@ def test_discrete_rv_clip():
3633 x_rv = at .random .poisson (2 )
3734 cens_x_rv = at .clip (x_rv , 1 , 4 )
3835
39- cens_x_vv = cens_x_rv .clone ()
40-
41- logp = joint_logprob ({cens_x_rv : cens_x_vv })
36+ logp , vv = joint_logprob (cens_x_rv )
4237 assert_no_rvs (logp )
4338
44- logp_fn = aesara .function ([ cens_x_vv ] , logp )
39+ logp_fn = aesara .function (vv , logp )
4540 ref_scipy = st .poisson (2 )
4641
4742 assert logp_fn (0 ) == - np .inf
@@ -57,11 +52,8 @@ def test_one_sided_clip():
5752 lb_cens_x_rv = at .clip (x_rv , - 1 , x_rv )
5853 ub_cens_x_rv = at .clip (x_rv , x_rv , 1 )
5954
60- lb_cens_x_vv = lb_cens_x_rv .clone ()
61- ub_cens_x_vv = ub_cens_x_rv .clone ()
62-
63- lb_logp = joint_logprob ({lb_cens_x_rv : lb_cens_x_vv })
64- ub_logp = joint_logprob ({ub_cens_x_rv : ub_cens_x_vv })
55+ lb_logp , (lb_cens_x_vv ,) = joint_logprob (lb_cens_x_rv )
56+ ub_logp , (ub_cens_x_vv ,) = joint_logprob (ub_cens_x_rv )
6557 assert_no_rvs (lb_logp )
6658 assert_no_rvs (ub_logp )
6759
@@ -78,9 +70,8 @@ def test_useless_clip():
7870 x_rv = at .random .normal (0.5 , 1 , size = 3 )
7971 cens_x_rv = at .clip (x_rv , x_rv , x_rv )
8072
81- cens_x_vv = cens_x_rv .clone ()
82-
83- logp = conditional_logprob ({cens_x_rv : cens_x_vv })[cens_x_rv ]
73+ logps , (cens_x_vv ,) = conditional_logprob (cens_x_rv )
74+ logp = logps [cens_x_rv ]
8475 assert_no_rvs (logp )
8576
8677 logp_fn = aesara .function ([cens_x_vv ], logp )
@@ -94,9 +85,7 @@ def test_random_clip():
9485 x_rv = at .random .normal (0 , 2 )
9586 cens_x_rv = at .clip (x_rv , lb_rv , [1 , 1 ])
9687
97- lb_vv = lb_rv .clone ()
98- cens_x_vv = cens_x_rv .clone ()
99- logps = conditional_logprob ({cens_x_rv : cens_x_vv , lb_rv : lb_vv })
88+ logps , (cens_x_vv , lb_vv ) = conditional_logprob (cens_x_rv , lb_rv )
10089 logp = at .add (* logps .values ())
10190 assert_no_rvs (logp )
10291
@@ -111,10 +100,7 @@ def test_broadcasted_clip_constant():
111100 x_rv = at .random .normal (0 , 2 )
112101 cens_x_rv = at .clip (x_rv , lb_rv , [1 , 1 ])
113102
114- lb_vv = lb_rv .clone ()
115- cens_x_vv = cens_x_rv .clone ()
116-
117- logp = joint_logprob ({cens_x_rv : cens_x_vv , lb_rv : lb_vv })
103+ logp , _ = joint_logprob (cens_x_rv , lb_rv )
118104 assert_no_rvs (logp )
119105
120106
@@ -123,10 +109,7 @@ def test_broadcasted_clip_random():
123109 x_rv = at .random .normal (0 , 2 , size = 2 )
124110 cens_x_rv = at .clip (x_rv , lb_rv , 1 )
125111
126- lb_vv = lb_rv .clone ()
127- cens_x_vv = cens_x_rv .clone ()
128-
129- logp = joint_logprob ({cens_x_rv : cens_x_vv , lb_rv : lb_vv })
112+ logp , _ = joint_logprob (cens_x_rv , lb_rv )
130113 assert_no_rvs (logp )
131114
132115
@@ -136,10 +119,8 @@ def test_fail_base_and_clip_have_values():
136119 cens_x_rv = at .clip (x_rv , x_rv , 1 )
137120 cens_x_rv .name = "cens_x"
138121
139- x_vv = x_rv .clone ()
140- cens_x_vv = cens_x_rv .clone ()
141122 with pytest .raises (RuntimeError , match = "could not be derived: {cens_x}" ):
142- conditional_logprob ({ cens_x_rv : cens_x_vv , x_rv : x_vv } )
123+ conditional_logprob (cens_x_rv , x_rv )
143124
144125
145126def test_fail_multiple_clip_single_base ():
@@ -150,20 +131,16 @@ def test_fail_multiple_clip_single_base():
150131 cens_rv2 = at .clip (base_rv , - 1 , 1 )
151132 cens_rv2 .name = "cens2"
152133
153- cens_vv1 = cens_rv1 .clone ()
154- cens_vv2 = cens_rv2 .clone ()
155134 with pytest .raises (RuntimeError , match = "could not be derived: {cens2}" ):
156- conditional_logprob ({ cens_rv1 : cens_vv1 , cens_rv2 : cens_vv2 } )
135+ conditional_logprob (cens_rv1 , cens_rv2 )
157136
158137
159138def test_deterministic_clipping ():
160139 x_rv = at .random .normal (0 , 1 )
161140 clip = at .clip (x_rv , 0 , 0 )
162141 y_rv = at .random .normal (clip , 1 )
163142
164- x_vv = x_rv .clone ()
165- y_vv = y_rv .clone ()
166- logp = joint_logprob ({x_rv : x_vv , y_rv : y_vv })
143+ logp , (x_vv , y_vv ) = joint_logprob (x_rv , y_rv )
167144 assert_no_rvs (logp )
168145
169146 logp_fn = aesara .function ([x_vv , y_vv ], logp )
@@ -180,7 +157,7 @@ def test_clip_transform():
180157 cens_x_vv = cens_x_rv .clone ()
181158
182159 transform = TransformValuesRewrite ({cens_x_vv : LogTransform ()})
183- logp = joint_logprob ({cens_x_rv : cens_x_vv }, extra_rewrites = transform )
160+ logp , _ = joint_logprob (realized = {cens_x_rv : cens_x_vv }, extra_rewrites = transform )
184161
185162 cens_x_vv_testval = - 1
186163 obs_logp = logp .eval ({cens_x_vv : cens_x_vv_testval })
@@ -201,8 +178,8 @@ def test_rounding(rounding_op):
201178 xr = rounding_op (x )
202179 xr .name = "xr"
203180
204- xr_vv = xr . clone ( )
205- logp = conditional_logprob ({ xr : xr_vv }) [xr ]
181+ logp , ( xr_vv ,) = conditional_logprob ( xr )
182+ logp = logp [xr ]
206183 assert logp is not None
207184
208185 x_sp = st .norm (loc , scale )
0 commit comments