Skip to content

Commit 23e2d59

Browse files
committed
fixed up tests
1 parent b9405ba commit 23e2d59

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

tests/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ def simple_init():
66
start, model, moments = simple_model()
77

88
step = Metropolis(model, model.vars, np.diag([1.]))
9-
return start, step, moments
9+
return model, start, step, moments
1010

1111

1212
def simple_model():
@@ -22,8 +22,8 @@ def simple_2model():
2222
tau = 1.3
2323
p = .4
2424
with Model() as model:
25-
x = Normal('x', mu,tau, testval = .1)
26-
y = Bernoulli('y', p)
25+
x = pm.Normal('x', mu,tau, testval = .1)
26+
y = pm.Bernoulli('y', p)
2727

2828
return model.test_point, model
2929

tests/test_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
test_parallel = False
1111

1212
def test_sample():
13-
start, step,_ = simple_init()
13+
model, start, step,_ = simple_init()
1414

1515
test_samplers = [sample]
1616
if test_parallel:
1717
test_samplers.append(psample)
1818

1919
for samplr in test_samplers:
2020
for n in [0, 10, 1000]:
21-
yield samplr, n, step, start
21+
yield samplr, model, n, step, start
2222

2323

2424

tests/test_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_step_continuous():
2626
for st in steps:
2727
for (var, stat, val, bound) in check:
2828
np.random.seed(1)
29-
h, _, _ = sample(8000, st, start)
29+
h = sample(model, 8000, st, start)
3030

3131
yield check_stat,repr(st), h, var, stat, val, bound
3232

tests/test_trace.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,44 +10,44 @@
1010
except:
1111
test_parallel = False
1212

13-
def check_trace(trace, n, step, start):
14-
13+
def check_trace(model, trace, n, step, start):
1514
#try using a trace object a few times
1615
for i in range(2):
17-
trace, _, _ = sample(n, step, start, trace)
16+
trace = sample(model, n, step, start, trace)
1817

1918
for (var, val) in start.iteritems():
2019

2120
assert np.shape(trace[var]) == (n*(i+1),) + np.shape(val)
2221

2322

2423
def test_trace():
25-
start, step,_ = simple_init()
24+
model, start, step,_ = simple_init()
2625

2726
for h in [pm.NpTrace]:
2827
for n in [20, 1000]:
29-
trace = h()
28+
trace = h(model.vars)
3029

31-
yield check_trace, trace, n, step, start
30+
yield check_trace, model, trace, n, step, start
3231

3332
def test_multitrace():
3433
if not test_parallel:
3534
return
36-
start, step,_ = simple_init()
35+
model, start, step,_ = simple_init()
3736
trace = None
3837
for n in [20, 1000]:
3938

40-
yield check_multi_trace, trace, n, step, start
39+
yield check_multi_trace, model, trace, n, step, start
4140

4241

4342

44-
def check_multi_trace(trace, n, step, start):
43+
def check_multi_trace(model, trace, n, step, start):
4544

46-
#try using a trace object a few times
4745
for i in range(2):
48-
trace, _, _ = psample(n, step, start, trace)
46+
trace = psample(model, n, step, start, trace)
47+
4948

5049
for (var, val) in start.iteritems():
50+
print [len(tr.samples[var].vals) for tr in trace.traces]
5151
for t in trace[var]:
5252
assert np.shape(t) == (n*(i+1),) + np.shape(val)
5353

0 commit comments

Comments
 (0)