@@ -117,30 +117,35 @@ def test_sample_tune_len(self):
117
117
assert len (trace ) == 100
118
118
119
119
120
- class SoftUpdate (SeededTest ):
120
+ class TestSoftUpdate (SeededTest ):
121
+ def setup_method (self ):
122
+ super (TestSoftUpdate , self ).setup_method ()
123
+
121
124
def test_soft_update_all_present (self ):
122
125
start = {'a' : 1 , 'b' : 2 }
123
126
test_point = {'a' : 3 , 'b' : 4 }
124
- pm .sampling ._soft_update (start , test_point )
127
+ pm .sampling ._update_start_vals (start , test_point , model = None )
125
128
assert start == {'a' : 1 , 'b' : 2 }
126
129
127
130
def test_soft_update_one_missing (self ):
128
131
start = {'a' : 1 , }
129
132
test_point = {'a' : 3 , 'b' : 4 }
130
- pm .sampling ._soft_update (start , test_point )
133
+ pm .sampling ._update_start_vals (start , test_point , model = None )
131
134
assert start == {'a' : 1 , 'b' : 4 }
132
135
133
136
def test_soft_update_empty (self ):
134
137
start = {}
135
138
test_point = {'a' : 3 , 'b' : 4 }
136
- pm .sampling ._soft_update (start , test_point )
139
+ pm .sampling ._update_start_vals (start , test_point , model = None )
137
140
assert start == test_point
138
141
139
142
def test_soft_update_transformed (self ):
140
- start = {'a' : 2 }
143
+ with pm .Model () as model :
144
+ pm .Exponential ('a' , 1 )
145
+ start = {'a' : 2. }
141
146
test_point = {'a_log__' : 0 }
142
- pm .sampling ._soft_update (start , test_point )
143
- assert assert_almost_equal (start ['a_log__' ], np . log ( start ['a' ]) )
147
+ pm .sampling ._update_start_vals (start , test_point , model )
148
+ assert_almost_equal (np . exp ( start ['a_log__' ]), start ['a' ])
144
149
145
150
146
151
@pytest .mark .xfail (condition = (theano .config .floatX == "float32" ), reason = "Fails on float32" )
0 commit comments