Skip to content

Commit 8b75eaf

Browse files
committed
revert bayes_nonconj to main
1 parent 6bb8dfd commit 8b75eaf

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

lectures/bayes_nonconj.md

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -556,20 +556,17 @@ class BayesianInference:
556556
Computes numerically the posterior distribution with beta prior parametrized by (alpha0, beta0)
557557
given data using MCMC
558558
"""
559-
560-
# Convert data to float32
561-
data = np.asarray(data, dtype=np.float32)
562-
563559
# use pyro
564560
if self.solver=='pyro':
565-
561+
# tensorize
562+
data = torch.tensor(data)
566563
nuts_kernel = NUTS(self.model)
567564
mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=num_warmup, disable_progbar=True)
568565
mcmc.run(data)
569566
570567
# use numpyro
571568
elif self.solver=='numpyro':
572-
569+
data = np.array(data, dtype=float)
573570
nuts_kernel = nNUTS(self.model)
574571
mcmc = nMCMC(nuts_kernel, num_samples=num_samples, num_warmup=num_warmup, progress_bar=False)
575572
mcmc.run(self.rng_key, data=data)
@@ -592,9 +589,9 @@ class BayesianInference:
592589
pyro.sample('theta', dist.Beta(alpha_q, beta_q))
593590
594591
else:
595-
alpha_q = numpyro.param('alpha_q', 10.0,
592+
alpha_q = numpyro.param('alpha_q', 10,
596593
constraint=nconstraints.positive)
597-
beta_q = numpyro.param('beta_q', 10.0,
594+
beta_q = numpyro.param('beta_q', 10,
598595
constraint=nconstraints.positive)
599596
600597
numpyro.sample('theta', ndist.Beta(alpha_q, beta_q))
@@ -652,31 +649,33 @@ class BayesianInference:
652649
params : the learned parameters for guide
653650
losses : a vector of loss at each step
654651
"""
655-
# Convert data to float32
656-
data = np.asarray(data, dtype=np.float32)
657652
658653
# initiate SVI
659654
svi = self.SVI_init(guide_dist=guide_dist)
660655
661656
# do gradient steps
662-
if self.solver == 'pyro':
657+
if self.solver=='pyro':
658+
# tensorize data
659+
if not torch.is_tensor(data):
660+
data = torch.tensor(data)
663661
# store loss vector
664-
losses = np.zeros(n_steps, dtype=np.float32)
662+
losses = np.zeros(n_steps)
665663
for step in range(n_steps):
666664
losses[step] = svi.step(data)
667665
668666
# pyro only supports beta VI distribution
669667
params = {
670668
'alpha_q': pyro.param('alpha_q').item(),
671669
'beta_q': pyro.param('beta_q').item()
672-
}
670+
}
673671
674-
elif self.solver == 'numpyro':
672+
elif self.solver=='numpyro':
673+
data = np.array(data, dtype=float)
675674
result = svi.run(self.rng_key, n_steps, data, progress_bar=False)
676-
params = {
677-
key: np.asarray(value, dtype=np.float32) for key, value in result.params.items()
678-
}
679-
losses = np.asarray(result.losses, dtype=np.float32)
675+
params = dict(
676+
(key, np.asarray(value)) for key, value in result.params.items()
677+
)
678+
losses = np.asarray(result.losses)
680679
681680
return params, losses
682681
```
@@ -967,18 +966,18 @@ We first initialize the `BayesianInference` classes and then can directly call `
967966
```{code-cell} ipython3
968967
# Initialize BayesianInference classes
969968
# try uniform
970-
STD_UNIFORM_pyro = BayesianInference(param=(0.0,1.0), name_dist='uniform', solver='pyro')
969+
STD_UNIFORM_pyro = BayesianInference(param=(0,1), name_dist='uniform', solver='pyro')
971970
UNIFORM_numpyro = BayesianInference(param=(0.2,0.7), name_dist='uniform', solver='numpyro')
972971
973972
# try truncated lognormal
974-
LOGNORMAL_numpyro = BayesianInference(param=(0.0,2.0), name_dist='lognormal', solver='numpyro')
975-
LOGNORMAL_pyro = BayesianInference(param=(0.0,2.0), name_dist='lognormal', solver='pyro')
973+
LOGNORMAL_numpyro = BayesianInference(param=(0,2), name_dist='lognormal', solver='numpyro')
974+
LOGNORMAL_pyro = BayesianInference(param=(0,2), name_dist='lognormal', solver='pyro')
976975
977976
# try von Mises
978977
# shifted von Mises
979-
VONMISES_numpyro = BayesianInference(param=10.0, name_dist='vonMises', solver='numpyro')
978+
VONMISES_numpyro = BayesianInference(param=10, name_dist='vonMises', solver='numpyro')
980979
# truncated von Mises
981-
VONMISES_pyro = BayesianInference(param=40.0, name_dist='vonMises', solver='pyro')
980+
VONMISES_pyro = BayesianInference(param=40, name_dist='vonMises', solver='pyro')
982981
983982
# try laplace
984983
LAPLACE_numpyro = BayesianInference(param=(0.5, 0.07), name_dist='laplace', solver='numpyro')

0 commit comments

Comments
 (0)