Skip to content

Commit 1c7b78d

Browse files
committed
fixed up tests
1 parent 02c5132 commit 1c7b78d

File tree

6 files changed

+34
-14
lines changed

6 files changed

+34
-14
lines changed

pymc/distributions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from continuous import *
22
from discrete import *
33
from transforms import *
4+
from distribution import *
45

56
import multivariate
67

pymc/distributions/continuous.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,10 +344,10 @@ def logp(value):
344344

345345
return locals()
346346

347-
def Tpos(nu, mu=0, lam=1):
347+
def Tpos(*args, **kwargs):
348348
"""
349349
Student-t distribution bounded at 0
350350
see T
351351
"""
352-
return Bound(T(nu, mu, lam), 0)
352+
return Bound(T(*args,**kwargs), 0)
353353

pymc/distributions/distribution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from ..quickclass import *
44
from ..model import *
5+
__all__ = ['DensityDist', 'TensorDist', 'tensordist', 'continuous', 'discrete', 'arbitrary']
56

67
class Distribution(object):
78
def __new__(cls, *args, **kwargs):

pymc/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def d2logpc(model, vars = None):
8787

8888
@property
8989
def test_point(self):
90-
return Point( (var, var.tag.test_value) for var in self.vars)
90+
return Point(self, ((var, var.tag.test_value) for var in self.vars))
9191

9292
@property
9393
def cont_vars(model):
@@ -175,7 +175,7 @@ def cont_inputs(f):
175175

176176
def gradient1(f, v):
177177
"""flat gradient of f wrt v"""
178-
return t.flatten(t.grad(f, v))
178+
return t.flatten(t.grad(f, v, disconnected_inputs='warn'))
179179

180180
def gradient(f, vars = None):
181181
if not vars:

pymc/sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def sample(model, draws, step, start = None, trace = None, progress_bar = True):
3838
draws = int(draws)
3939
if start is None:
4040
start = trace[-1]
41-
point = Point(start)
41+
point = Point(model, start)
4242

4343
if not hasattr(trace, 'record'):
4444
if trace is None:

tests/test_distributions.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,23 +68,41 @@ def test_poisson():
6868
def test_constantdist():
6969
checkd(ConstantDist, I, {'c' : I})
7070

71+
def test_zeroinflatedpoisson():
72+
checkd(ZeroInflatedPoisson, I, {'theta' : Rplus, 'z' : Bool})
73+
74+
75+
def test_densitydist():
76+
def logp(x):
77+
return -log(2*.5) - abs(x-.5)/.5
78+
79+
checkd(DensityDist,R, {}, extra_args = {'logp' : logp})
80+
81+
82+
def test_addpotential():
83+
with Model() as model:
84+
x = Normal('x', 1,1)
85+
model.AddPotential(-x**2)
86+
87+
check_dlogp(model, x, [R])
7188

72-
def checkd(distfam, valuedomain, vardomains, check_int = True, check_der = True):
7389

74-
m = Model()
90+
def checkd(distfam, valuedomain, vardomains, check_int = True, check_der = True, extra_args = {}):
7591

76-
with m:
92+
with Model() as m:
7793
vars = dict((v , Flat(v, dtype = dom.dtype)) for v,dom in vardomains.iteritems())
94+
vars.update(extra_args)
95+
print vars
7896
value = distfam('value', testval = valuedomain[len(valuedomain)//2], **vars)
7997

80-
vardomains['value'] = np.array(valuedomain)
98+
vardomains['value'] = np.array(valuedomain)
8199

82-
domains = [np.array(vardomains[str(v)]) for v in m.vars]
100+
domains = [np.array(vardomains[str(v)]) for v in m.vars]
83101

84-
if check_int:
85-
check_int_to_1(m, value, domains)
86-
if check_der:
87-
check_dlogp(m, value, domains)
102+
if check_int:
103+
check_int_to_1(m, value, domains)
104+
if check_der:
105+
check_dlogp(m, value, domains)
88106

89107

90108
def check_int_to_1(model, value, domains):

0 commit comments

Comments
 (0)