Skip to content

Commit 298fa23

Browse files
committed
move to floatX
1 parent 51108d4 commit 298fa23

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,26 +230,26 @@ def test_replacements_in_sample_node_aevb(three_var_aevb_approx, aevb_initial):
230230

231231
def test_vae():
232232
minibatch_size = 10
233-
data = np.random.rand(100).astype('float32')
233+
data = pm.floatX(np.random.rand(100))
234234
x_mini = pm.Minibatch(data, minibatch_size)
235235
x_inp = tt.vector()
236236
x_inp.tag.test_value = data[:minibatch_size]
237237

238-
ae = theano.shared(np.asarray([.1, .1], 'float32'))
239-
be = theano.shared(np.asarray(1., dtype='float32'))
238+
ae = theano.shared(pm.floatX([.1, .1]))
239+
be = theano.shared(pm.floatX(1.))
240240

241-
ad = theano.shared(np.asarray(1., dtype='float32'))
242-
bd = theano.shared(np.asarray(1., dtype='float32'))
241+
ad = theano.shared(pm.floatX(1.))
242+
bd = theano.shared(pm.floatX(1.))
243243

244244
enc = x_inp.dimshuffle(0, 'x') * ae.dimshuffle('x', 0) + be
245245
mu, rho = enc[:, 0], enc[:, 1]
246246

247247
with pm.Model():
248248
# Hidden variables
249-
zs = pm.Normal('zs', mu=0, sd=1, shape=minibatch_size, dtype='float32')
249+
zs = pm.Normal('zs', mu=0, sd=1, shape=minibatch_size)
250250
dec = zs * ad + bd
251251
# Observation model
252-
pm.Normal('xs_', mu=dec, sd=0.1, observed=x_inp, dtype='float32')
252+
pm.Normal('xs_', mu=dec, sd=0.1, observed=x_inp)
253253

254254
pm.fit(1, local_rv={zs: dict(mu=mu, rho=rho)},
255255
more_replacements={x_inp: x_mini}, more_obj_params=[ae, be, ad, bd])

0 commit comments

Comments
 (0)