@@ -230,26 +230,26 @@ def test_replacements_in_sample_node_aevb(three_var_aevb_approx, aevb_initial):
230
230
231
231
def test_vae ():
232
232
minibatch_size = 10
233
- data = np .random .rand (100 ). astype ( 'float32' )
233
+ data = pm . floatX ( np .random .rand (100 ))
234
234
x_mini = pm .Minibatch (data , minibatch_size )
235
235
x_inp = tt .vector ()
236
236
x_inp .tag .test_value = data [:minibatch_size ]
237
237
238
- ae = theano .shared (np . asarray ([.1 , .1 ], 'float32' ))
239
- be = theano .shared (np . asarray (1. , dtype = 'float32' ))
238
+ ae = theano .shared (pm . floatX ([.1 , .1 ]))
239
+ be = theano .shared (pm . floatX (1. ))
240
240
241
- ad = theano .shared (np . asarray (1. , dtype = 'float32' ))
242
- bd = theano .shared (np . asarray (1. , dtype = 'float32' ))
241
+ ad = theano .shared (pm . floatX (1. ))
242
+ bd = theano .shared (pm . floatX (1. ))
243
243
244
244
enc = x_inp .dimshuffle (0 , 'x' ) * ae .dimshuffle ('x' , 0 ) + be
245
245
mu , rho = enc [:, 0 ], enc [:, 1 ]
246
246
247
247
with pm .Model ():
248
248
# Hidden variables
249
- zs = pm .Normal ('zs' , mu = 0 , sd = 1 , shape = minibatch_size , dtype = 'float32' )
249
+ zs = pm .Normal ('zs' , mu = 0 , sd = 1 , shape = minibatch_size )
250
250
dec = zs * ad + bd
251
251
# Observation model
252
- pm .Normal ('xs_' , mu = dec , sd = 0.1 , observed = x_inp , dtype = 'float32' )
252
+ pm .Normal ('xs_' , mu = dec , sd = 0.1 , observed = x_inp )
253
253
254
254
pm .fit (1 , local_rv = {zs : dict (mu = mu , rho = rho )},
255
255
more_replacements = {x_inp : x_mini }, more_obj_params = [ae , be , ad , bd ])
0 commit comments