@@ -99,6 +99,9 @@ class MvNormal(Continuous):
     def __init__(self, mu, cov=None, tau=None, chol=None, lower=True,
                  *args, **kwargs):
         super(MvNormal, self).__init__(*args, **kwargs)
+        if len(self.shape) > 2:
+            raise ValueError("Only 1 or 2 dimensions are allowed.")
+
         if not lower:
             chol = chol.T
         if len([i for i in [tau, cov, chol] if i is not None]) != 1:
@@ -122,6 +125,7 @@ def __init__(self, mu, cov=None, tau=None, chol=None, lower=True,
                 raise ValueError('cov must be two dimensional.')
             self.chol_cov = cholesky(cov)
             self.cov = cov
+            self._n = self.cov.shape[-1]
         elif tau is not None:
             self.k = tau.shape[0]
             self._cov_type = 'tau'
@@ -130,13 +134,15 @@ def __init__(self, mu, cov=None, tau=None, chol=None, lower=True,
                 raise ValueError('tau must be two dimensional.')
             self.chol_tau = cholesky(tau)
             self.tau = tau
+            self._n = self.tau.shape[-1]
         else:
             self.k = chol.shape[0]
             self._cov_type = 'chol'
             if chol.ndim != 2:
                 raise ValueError('chol must be two dimensional.')
             self.chol_cov = tt.as_tensor_variable(chol)
-
+            self._n = self.chol_cov.shape[-1]
+
     def random(self, point=None, size=None):
         if size is None:
             size = []
@@ -148,6 +154,9 @@ def random(self, point=None, size=None):
 
         if self._cov_type == 'cov':
             mu, cov = draw_values([self.mu, self.cov], point=point)
+            if mu.shape != cov[0].shape:
+                raise ValueError("Shapes for mu and cov don't match")
+
             try:
                 dist = stats.multivariate_normal(
                     mean=mu, cov=cov, allow_singular=True)
@@ -157,11 +166,17 @@ def random(self, point=None, size=None):
             return dist.rvs(size)
         elif self._cov_type == 'chol':
             mu, chol = draw_values([self.mu, self.chol_cov], point=point)
+            if mu.shape != chol[0].shape:
+                raise ValueError("Shapes for mu and chol don't match")
+
             size.append(mu.shape[0])
             standard_normal = np.random.standard_normal(size)
             return mu + np.dot(standard_normal, chol.T)
         else:
             mu, tau = draw_values([self.mu, self.tau], point=point)
+            if mu.shape != tau[0].shape:
+                raise ValueError("Shapes for mu and tau don't match")
+
             size.append(mu.shape[0])
             standard_normal = np.random.standard_normal(size)
             try:
@@ -171,23 +186,32 @@ def random(self, point=None, size=None):
             transformed = linalg.solve_triangular(
                 chol, standard_normal.T, lower=True)
             return mu + transformed.T
-
+
     def logp(self, value):
         mu = self.mu
-        k = mu.shape[-1]
-
-        value = value.reshape((-1, k))
+        if value.ndim > 2 or value.ndim == 0:
+            raise ValueError('Invalid dimension for value: %s' % value.ndim)
+        if value.ndim == 1:
+            onedim = True
+            value = value[None, :]
+        else:
+            onedim = False
+
         delta = value - mu
-
+
         if self._cov_type == 'cov':
             # Use this when Theano#5908 is released.
             # return MvNormalLogp()(self.cov, delta)
-            return self._logp_cov(delta)
+            logp = self._logp_cov(delta)
         elif self._cov_type == 'tau':
-            return self._logp_tau(delta)
+            logp = self._logp_tau(delta)
         else:
-            return self._logp_chol(delta)
-
+            logp = self._logp_chol(delta)
+
+        if onedim:
+            return logp[0]
+        return logp
+
     def _logp_chol(self, delta):
         chol_cov = self.chol_cov
         n, k = delta.shape
@@ -200,10 +224,10 @@ def _logp_chol(self, delta):
         chol_cov = tt.switch(ok, chol_cov, 1)
 
         delta_trans = self.solve_lower(chol_cov, delta.T)
-
-        result = n * k * np.log(2 * np.pi)
-        result += 2.0 * n * tt.sum(tt.log(diag))
-        result += (delta_trans ** 2).sum()
+
+        result = k * np.log(2 * np.pi)
+        result += 2.0 * tt.sum(tt.log(diag))
+        result += (delta_trans ** 2).sum(axis=0)
         result = -0.5 * result
         return bound(result, ok)
 
@@ -217,10 +241,10 @@ def _logp_cov(self, delta):
         chol_cov = tt.switch(ok, chol_cov, 1)
         diag = tt.nlinalg.diag(chol_cov)
         delta_trans = self.solve_lower(chol_cov, delta.T)
-
-        result = n * k * tt.log(2 * np.pi)
-        result += 2.0 * n * tt.sum(tt.log(diag))
-        result += (delta_trans ** 2).sum()
+
+        result = k * tt.log(2 * np.pi)
+        result += 2.0 * tt.sum(tt.log(diag))
+        result += (delta_trans ** 2).sum(axis=0)
         result = -0.5 * result
         return bound(result, ok)
 
@@ -234,10 +258,10 @@ def _logp_tau(self, delta):
         chol_tau = tt.switch(ok, chol_tau, 1)
         diag = tt.nlinalg.diag(chol_tau)
         delta_trans = tt.dot(chol_tau.T, delta.T)
-
-        result = n * k * tt.log(2 * np.pi)
-        result -= 2.0 * n * tt.sum(tt.log(diag))
-        result += (delta_trans ** 2).sum()
+
+        result = k * tt.log(2 * np.pi)
+        result -= 2.0 * tt.sum(tt.log(diag))
+        result += (delta_trans ** 2).sum(axis=0)
         result = -0.5 * result
         return bound(result, ok)
 
@@ -247,7 +271,7 @@ def _repr_latex_(self, name=None, dist=None):
         mu = dist.mu
         try:
             cov = dist.cov
-        except AttributeErrir:
+        except AttributeError:
             cov = dist.chol_cov
         return r'${} \sim \text{{MvNormal}}(\mathit{{mu}}={}, \mathit{{cov}}={})$'.format(name,
                                                                 get_variable_name(mu),
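
For reference, a sketch of the per-observation quantity the reworked _logp_chol computes (implied by the code above, not stated in the commit itself): with Sigma factored as a lower-triangular Cholesky Sigma = L L^T and x_i a row of value,

    \log p(x_i \mid \mu, \Sigma) = -\tfrac{1}{2}\Big[\, k \log(2\pi) + 2\sum_{j=1}^{k} \log L_{jj} + \lVert L^{-1}(x_i - \mu) \rVert^2 \,\Big]

The _logp_cov variant is analogous, and _logp_tau works with the Cholesky factor of the precision, which flips the sign of the log-determinant term. Replacing .sum() with .sum(axis=0) keeps one such term per observation, which is what lets logp return a vector for 2-D values.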
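
As a quick illustration of the behaviour the patch enables (a minimal, hypothetical sketch, not part of the commit; the variable names and data are made up), a 2-D observed array can be handed to MvNormal directly and each row contributes its own log-probability term:

    import numpy as np
    import pymc3 as pm

    # Hypothetical data: 200 observations of a 3-dimensional variable.
    data = np.random.randn(200, 3)
    cov = np.eye(3)

    with pm.Model():
        mu = pm.Normal('mu', mu=0., sd=10., shape=3)
        # With this change the 2-D `observed` array is accepted as-is and
        # its log-probability is accumulated row by row, with no manual
        # reshape of the data.
        obs = pm.MvNormal('obs', mu=mu, cov=cov, observed=data)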