Skip to content

Commit 16a8cda

Browse files
committed
small changes to means
1 parent a4bb63b commit 16a8cda

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

pymc3/gp/mean.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ class Zero(Mean):
3232
"""
3333

3434
def __call__(self, X):
35-
return tt.zeros(tt.stack([X.shape[0], ]), dtype='float32')
36-
35+
return tt.alloc(0.0, X.shape[0])
3736

3837
class Constant(Mean):
3938
R"""
@@ -50,7 +49,7 @@ def __init__(self, c=0):
5049
self.c = c
5150

5251
def __call__(self, X):
53-
return tt.ones(tt.stack([X.shape[0], ])) * self.c
52+
return tt.alloc(1.0, X.shape[0]) * self.c
5453

5554

5655
class Linear(Mean):
@@ -71,7 +70,7 @@ def __init__(self, coeffs, intercept=0):
7170
self.A = coeffs
7271

7372
def __call__(self, X):
74-
return (tt.dot(X, self.A) + self.b).squeeze()
73+
return tt.squeeze(tt.dot(X, self.A) + self.b)
7574

7675

7776
class Add(Mean):

0 commit comments

Comments
 (0)