2
2
import theano .tensor as tt
3
3
import numpy as np
4
4
from functools import reduce
5
+ from operator import mul , add
5
6
6
7
__all__ = ['ExpQuad' ,
7
8
'RatQuad' ,
@@ -91,9 +92,8 @@ def __array_wrap__(self, result):
91
92
92
93
class Combination (Covariance ):
93
94
def __init__ (self , factor_list ):
94
- input_dim = np .max ([factor .input_dim for factor in
95
- filter (lambda x : isinstance (x , Covariance ),
96
- factor_list )])
95
+ input_dim = max ([factor .input_dim for factor in factor_list
96
+ if isinstance (factor , Covariance )])
97
97
super (Combination , self ).__init__ (input_dim = input_dim )
98
98
self .factor_list = []
99
99
for factor in factor_list :
@@ -103,45 +103,36 @@ def __init__(self, factor_list):
103
103
self .factor_list .append (factor )
104
104
105
105
def merge_factors (self , X , Xs = None , diag = False ):
106
- # this function makes sure diag=True is handled properly
107
106
factor_list = []
108
107
for factor in self .factor_list :
109
-
110
- # if factor is a Covariance
108
+ # make sure diag=True is handled properly
111
109
if isinstance (factor , Covariance ):
112
110
factor_list .append (factor (X , Xs , diag ))
113
- continue
114
-
115
- # if factor is a numpy array
116
- if isinstance (factor , np .ndarray ):
117
- if np .ndim (factor ) == 2 :
118
- if diag :
119
- factor_list .append (np .diag (factor ))
120
- continue
121
-
122
- # if factor is a theano variable with ndim attribute
123
- if isinstance (factor , (tt .TensorConstant ,
111
+ elif isinstance (factor , np .ndarray ):
112
+ if np .ndim (factor ) == 2 and diag :
113
+ factor_list .append (np .diag (factor ))
114
+ else :
115
+ factor_list .append (factor )
116
+ elif isinstance (factor , (tt .TensorConstant ,
124
117
tt .TensorVariable ,
125
118
tt .sharedvar .TensorSharedVariable )):
126
- if factor .ndim == 2 :
127
- if diag :
128
- factor_list .append (tt .diag (factor ))
129
- continue
130
-
131
- # othewise
132
- factor_list .append (factor )
133
-
119
+ if factor .ndim == 2 and diag :
120
+ factor_list .append (tt .diag (factor ))
121
+ else :
122
+ factor_list .append (factor )
123
+ else :
124
+ factor_list .append (factor )
134
125
return factor_list
135
126
136
127
137
128
class Add (Combination ):
138
129
def __call__ (self , X , Xs = None , diag = False ):
139
- return reduce (( lambda x , y : x + y ) , self .merge_factors (X , Xs , diag ))
130
+ return reduce (add , self .merge_factors (X , Xs , diag ))
140
131
141
132
142
133
class Prod (Combination ):
143
134
def __call__ (self , X , Xs = None , diag = False ):
144
- return reduce (( lambda x , y : x * y ) , self .merge_factors (X , Xs , diag ))
135
+ return reduce (mul , self .merge_factors (X , Xs , diag ))
145
136
146
137
147
138
class Constant (Covariance ):
@@ -205,7 +196,7 @@ def __init__(self, input_dim, ls=None, ls_inv=None, active_dims=None):
205
196
if (ls is None and ls_inv is None ) or (ls is not None and ls_inv is not None ):
206
197
raise ValueError ("Only one of 'ls' or 'ls_inv' must be provided" )
207
198
elif ls_inv is not None :
208
- if isinstance (ls_inv , (np . ndarray , list , tuple )):
199
+ if isinstance (ls_inv , (list , tuple )):
209
200
ls = 1.0 / np .asarray (ls_inv )
210
201
else :
211
202
ls = 1.0 / ls_inv
@@ -229,7 +220,8 @@ def euclidean_dist(self, X, Xs):
229
220
return tt .sqrt (r2 + 1e-12 )
230
221
231
222
def diag (self , X ):
232
- return tt .ones (tt .stack ([X .shape [0 ], ]))
223
+ return tt .alloc (1.0 , X .shape [0 ])
224
+ #return tt.ones(tt.stack([X.shape[0], ]))
233
225
234
226
def full (self , X , Xs = None ):
235
227
raise NotImplementedError
@@ -274,7 +266,7 @@ class RatQuad(Stationary):
274
266
k(x, x') = \left(1 + \frac{(x - x')^2}{2\alpha\ell^2} \right)^{-\alpha}
275
267
"""
276
268
277
- def __init__ (self , input_dim , alpha , ls , ls_inv , active_dims = None ):
269
+ def __init__ (self , input_dim , alpha , ls = None , ls_inv = None , active_dims = None ):
278
270
super (RatQuad , self ).__init__ (input_dim , ls , ls_inv , active_dims )
279
271
self .alpha = alpha
280
272
@@ -337,7 +329,7 @@ class Cosine(Stationary):
337
329
The Cosine kernel.
338
330
339
331
.. math::
340
- k(x, x') = \mathrm{cos}\left( \frac{||x - x'||}{ \ell^2} \right)
332
+ k(x, x') = \mathrm{cos}\left( \pi \ frac{||x - x'||}{ \ell^2} \right)
341
333
"""
342
334
343
335
def full (self , X , Xs = None ):
0 commit comments