Skip to content

Commit 6bb8dfd

Browse files
committed
fix the same issue for MCMC plots
1 parent f6a52e3 commit 6bb8dfd

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

lectures/bayes_nonconj.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -556,8 +556,9 @@ class BayesianInference:
556556
Computes numerically the posterior distribution with beta prior parametrized by (alpha0, beta0)
557557
given data using MCMC
558558
"""
559-
# tensorize
560-
data = torch.tensor(data)
559+
560+
# Convert data to float32
561+
data = np.asarray(data, dtype=np.float32)
561562
562563
# use pyro
563564
if self.solver=='pyro':
@@ -966,18 +967,18 @@ We first initialize the `BayesianInference` classes and then can directly call `
966967
```{code-cell} ipython3
967968
# Initialize BayesianInference classes
968969
# try uniform
969-
STD_UNIFORM_pyro = BayesianInference(param=(0,1), name_dist='uniform', solver='pyro')
970+
STD_UNIFORM_pyro = BayesianInference(param=(0.0,1.0), name_dist='uniform', solver='pyro')
970971
UNIFORM_numpyro = BayesianInference(param=(0.2,0.7), name_dist='uniform', solver='numpyro')
971972
972973
# try truncated lognormal
973-
LOGNORMAL_numpyro = BayesianInference(param=(0,2), name_dist='lognormal', solver='numpyro')
974-
LOGNORMAL_pyro = BayesianInference(param=(0,2), name_dist='lognormal', solver='pyro')
974+
LOGNORMAL_numpyro = BayesianInference(param=(0.0,2.0), name_dist='lognormal', solver='numpyro')
975+
LOGNORMAL_pyro = BayesianInference(param=(0.0,2.0), name_dist='lognormal', solver='pyro')
975976
976977
# try von Mises
977978
# shifted von Mises
978-
VONMISES_numpyro = BayesianInference(param=10, name_dist='vonMises', solver='numpyro')
979+
VONMISES_numpyro = BayesianInference(param=10.0, name_dist='vonMises', solver='numpyro')
979980
# truncated von Mises
980-
VONMISES_pyro = BayesianInference(param=40, name_dist='vonMises', solver='pyro')
981+
VONMISES_pyro = BayesianInference(param=40.0, name_dist='vonMises', solver='pyro')
981982
982983
# try laplace
983984
LAPLACE_numpyro = BayesianInference(param=(0.5, 0.07), name_dist='laplace', solver='numpyro')

0 commit comments

Comments
 (0)