@@ -556,20 +556,17 @@ class BayesianInference:
556556 Computes numerically the posterior distribution with beta prior parametrized by (alpha0, beta0)
557557 given data using MCMC
558558 """
559-
560- # Convert data to float32
561- data = np.asarray(data, dtype=np.float32)
562-
563559 # use pyro
564560 if self.solver=='pyro':
565-
561+ # tensorize
562+ data = torch.tensor(data)
566563 nuts_kernel = NUTS(self.model)
567564 mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=num_warmup, disable_progbar=True)
568565 mcmc.run(data)
569566
570567 # use numpyro
571568 elif self.solver=='numpyro':
572-
569+ data = np.array(data, dtype=float)
573570 nuts_kernel = nNUTS(self.model)
574571 mcmc = nMCMC(nuts_kernel, num_samples=num_samples, num_warmup=num_warmup, progress_bar=False)
575572 mcmc.run(self.rng_key, data=data)
@@ -592,9 +589,9 @@ class BayesianInference:
592589 pyro.sample('theta', dist.Beta(alpha_q, beta_q))
593590
594591 else:
595- alpha_q = numpyro.param('alpha_q', 10.0 ,
592+ alpha_q = numpyro.param('alpha_q', 10,
596593 constraint=nconstraints.positive)
597- beta_q = numpyro.param('beta_q', 10.0 ,
594+ beta_q = numpyro.param('beta_q', 10,
598595 constraint=nconstraints.positive)
599596
600597 numpyro.sample('theta', ndist.Beta(alpha_q, beta_q))
@@ -652,31 +649,33 @@ class BayesianInference:
652649 params : the learned parameters for guide
653650 losses : a vector of loss at each step
654651 """
655- # Convert data to float32
656- data = np.asarray(data, dtype=np.float32)
657652
658653 # initiate SVI
659654 svi = self.SVI_init(guide_dist=guide_dist)
660655
661656 # do gradient steps
662- if self.solver == 'pyro':
657+ if self.solver=='pyro':
658+ # tensorize data
659+ if not torch.is_tensor(data):
660+ data = torch.tensor(data)
663661 # store loss vector
664- losses = np.zeros(n_steps, dtype=np.float32 )
662+ losses = np.zeros(n_steps)
665663 for step in range(n_steps):
666664 losses[step] = svi.step(data)
667665
668666 # pyro only supports beta VI distribution
669667 params = {
670668 'alpha_q': pyro.param('alpha_q').item(),
671669 'beta_q': pyro.param('beta_q').item()
672- }
670+ }
673671
674- elif self.solver == 'numpyro':
672+ elif self.solver=='numpyro':
673+ data = np.array(data, dtype=float)
675674 result = svi.run(self.rng_key, n_steps, data, progress_bar=False)
676- params = {
677- key: np.asarray(value, dtype=np.float32 ) for key, value in result.params.items()
678- }
679- losses = np.asarray(result.losses, dtype=np.float32 )
675+ params = dict(
676+ ( key, np.asarray(value) ) for key, value in result.params.items()
677+ )
678+ losses = np.asarray(result.losses)
680679
681680 return params, losses
682681```
@@ -967,18 +966,18 @@ We first initialize the `BayesianInference` classes and then can directly call `
967966```{code-cell} ipython3
968967# Initialize BayesianInference classes
969968# try uniform
970- STD_UNIFORM_pyro = BayesianInference(param=(0.0,1.0 ), name_dist='uniform', solver='pyro')
969+ STD_UNIFORM_pyro = BayesianInference(param=(0,1 ), name_dist='uniform', solver='pyro')
971970UNIFORM_numpyro = BayesianInference(param=(0.2,0.7), name_dist='uniform', solver='numpyro')
972971
973972# try truncated lognormal
974- LOGNORMAL_numpyro = BayesianInference(param=(0.0,2.0 ), name_dist='lognormal', solver='numpyro')
975- LOGNORMAL_pyro = BayesianInference(param=(0.0,2.0 ), name_dist='lognormal', solver='pyro')
973+ LOGNORMAL_numpyro = BayesianInference(param=(0,2 ), name_dist='lognormal', solver='numpyro')
974+ LOGNORMAL_pyro = BayesianInference(param=(0,2 ), name_dist='lognormal', solver='pyro')
976975
977976# try von Mises
978977# shifted von Mises
979- VONMISES_numpyro = BayesianInference(param=10.0 , name_dist='vonMises', solver='numpyro')
978+ VONMISES_numpyro = BayesianInference(param=10, name_dist='vonMises', solver='numpyro')
980979# truncated von Mises
981- VONMISES_pyro = BayesianInference(param=40.0 , name_dist='vonMises', solver='pyro')
980+ VONMISES_pyro = BayesianInference(param=40, name_dist='vonMises', solver='pyro')
982981
983982# try laplace
984983LAPLACE_numpyro = BayesianInference(param=(0.5, 0.07), name_dist='laplace', solver='numpyro')
0 commit comments