Skip to content

Commit 4071c7d

Browse files
fonnesbecktwiecki
authored andcommitted
Fixed TestSoftUpdate
1 parent 02a9aef commit 4071c7d

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

pymc3/tests/test_sampling.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,30 +117,35 @@ def test_sample_tune_len(self):
117117
assert len(trace) == 100
118118

119119

120-
class SoftUpdate(SeededTest):
120+
class TestSoftUpdate(SeededTest):
121+
def setup_method(self):
122+
super(TestSoftUpdate, self).setup_method()
123+
121124
def test_soft_update_all_present(self):
122125
start = {'a': 1, 'b': 2}
123126
test_point = {'a': 3, 'b': 4}
124-
pm.sampling._soft_update(start, test_point)
127+
pm.sampling._update_start_vals(start, test_point, model=None)
125128
assert start == {'a': 1, 'b': 2}
126129

127130
def test_soft_update_one_missing(self):
128131
start = {'a': 1, }
129132
test_point = {'a': 3, 'b': 4}
130-
pm.sampling._soft_update(start, test_point)
133+
pm.sampling._update_start_vals(start, test_point, model=None)
131134
assert start == {'a': 1, 'b': 4}
132135

133136
def test_soft_update_empty(self):
134137
start = {}
135138
test_point = {'a': 3, 'b': 4}
136-
pm.sampling._soft_update(start, test_point)
139+
pm.sampling._update_start_vals(start, test_point, model=None)
137140
assert start == test_point
138141

139142
def test_soft_update_transformed(self):
140-
start = {'a': 2}
143+
with pm.Model() as model:
144+
pm.Exponential('a', 1)
145+
start = {'a': 2.}
141146
test_point = {'a_log__': 0}
142-
pm.sampling._soft_update(start, test_point)
143-
assert assert_almost_equal(start['a_log__'], np.log(start['a']))
147+
pm.sampling._update_start_vals(start, test_point, model)
148+
assert_almost_equal(np.exp(start['a_log__']), start['a'])
144149

145150

146151
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")

0 commit comments

Comments
 (0)