
Commit 3056dfb

modify figure style
1 parent 9e92870 commit 3056dfb

File tree

1 file changed: +141 -26


lectures/bayes_nonconj.md

Lines changed: 141 additions & 26 deletions
@@ -228,7 +228,7 @@ def TruncatedLogNormal_trans(loc, scale):
     """
     base_dist = ndist.TruncatedNormal(
         low=jnp.log(0), high=jnp.log(1), loc=loc, scale=scale
-    ) #TODO:is it fine to use log(0)?
+    )
     return ndist.TransformedDistribution(base_dist, ndist.transforms.ExpTransform())
 
 
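The deleted TODO asked whether `jnp.log(0)` is safe as a lower truncation bound. In JAX it is: `jnp.log(0.)` evaluates to `-inf` rather than raising, so the base normal is effectively untruncated from below and `ExpTransform` maps the support back to $(0, 1)$. A minimal check, assuming only `jax` is installed (not part of the diff):

```python
import jax.numpy as jnp

# log(0) is -inf in JAX, not an error; exp maps it back to 0,
# so the transformed distribution has support (0, 1).
print(jnp.log(0.0))           # -inf
print(jnp.exp(jnp.log(0.0)))  # 0.0
```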
@@ -279,7 +279,7 @@ Consider a **guide distribution** $q_{\phi}(\theta)$ parameterized by $\phi$ tha
 We choose parameters $\phi$ of the guide distribution to minimize a Kullback-Leibler (KL) divergence between the approximate posterior $q_{\phi}(\theta)$ and the posterior:
 
 $$
-D_{KL}(q(\theta;\phi)\;\|\;p(\theta\mid Y)) \equiv -\int d\theta q(\theta;\phi)\log\frac{p(\theta\mid Y)}{q(\theta;\phi)}
+D_{KL}(q(\theta;\phi)\;\|\;p(\theta\mid Y)) \equiv -\int q(\theta;\phi)\log\frac{p(\theta\mid Y)}{q(\theta;\phi)} d\theta
 $$
 
 Thus, we want a **variational distribution** $q$ that solves
@@ -291,7 +291,7 @@ $$
 Note that
 
 $$
-\begin{aligned}D_{KL}(q(\theta;\phi)\;\|\;p(\theta\mid Y)) & =-\int d\theta q(\theta;\phi)\log\frac{P(\theta\mid Y)}{q(\theta;\phi)}\\
+\begin{aligned}D_{KL}(q(\theta;\phi)\;\|\;p(\theta\mid Y)) & =-\int q(\theta;\phi)\log\frac{p(\theta\mid Y)}{q(\theta;\phi)} d\theta\\
 & =-\int q(\theta)\log\frac{\frac{p(\theta,Y)}{p(Y)}}{q(\theta)} d\theta\\
 & =-\int q(\theta)\log\frac{p(\theta,Y)}{p(Y)q(\theta)} d\theta\\
 & =-\int q(\theta)\left[\log\frac{p(\theta,Y)}{q(\theta)}-\log p(Y)\right] d\theta\\
@@ -533,10 +533,26 @@ Let's see how well our sampling algorithm does in approximating
 To examine our alternative prior distributions, we'll plot approximate prior distributions below by calling the `show_prior` method.
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      Truncated log normal distribution
+    name: fig_lognormal_dist
+---
 # truncated log normal
 exampleLN = BayesianInference(param=(0, 2), name_dist="lognormal")
 exampleLN.show_prior(size=100000, bins=20)
+```
 
+```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      Truncated uniform distribution
+    name: fig_uniform_dist
+---
 # truncated uniform
 exampleUN = BayesianInference(param=(0.1, 0.8), name_dist="uniform")
 exampleUN.show_prior(size=100000, bins=20)
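A note on the pattern these hunks introduce: in MyST-NB, a `mystnb: figure:` block at the top of a code cell turns the cell's output into a numbered figure, and the `name` key makes it cross-referenceable from prose, e.g. ``{numref}`fig_lognormal_dist` ``. The captions are the commit's own; the cross-reference example here is illustrative only.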
@@ -548,6 +564,13 @@ Now let's see how well things work with von Mises distributions.
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      Shifted von Mises distribution
+    name: fig_vonmises_dist
+---
 # shifted von Mises
 exampleVM = BayesianInference(param=10, name_dist="vonMises")
 exampleVM.show_prior(size=100000, bins=20)
 ```
@@ -557,6 +580,13 @@ The graphs look good too.
 Now let's try with a Laplace distribution.
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      Truncated Laplace distribution
+    name: fig_laplace_dist
+---
 # truncated Laplace
 exampleLP = BayesianInference(param=(0.5, 0.05), name_dist="laplace")
 exampleLP.show_prior(size=100000, bins=40)
@@ -609,7 +639,7 @@ class BayesianInferencePlot:
         self.data = simulate_draw(theta, N_max)
 
     def MCMC_plot(self, num_samples, num_warmup=1000):
-        fig, ax = plt.subplots(figsize=(10, 6))
+        fig, ax = plt.subplots()
 
         # plot prior
         prior_sample = self.BayesianInferenceClass.show_prior(disp_plot=0)
@@ -641,14 +671,11 @@ class BayesianInferencePlot:
                 label=f"Posterior with $n={n}$",
             )
         ax.legend(loc="upper left")
-        ax.set_title("MCMC sampling density of posterior distributions", fontsize=15)
         plt.xlim(0, 1)
         plt.show()
 
     def SVI_fitting(self, guide_dist, params):
-        """
-        Fit the beta/truncnormal curve using parameters trained by SVI.
-        """
+        """Fit the beta/truncnormal curve using parameters trained by SVI."""
         # create x axis
         xaxis = np.linspace(0, 1, 1000)
         if guide_dist == "beta":
@@ -666,7 +693,7 @@ class BayesianInferencePlot:
         return (xaxis, y)
 
     def SVI_plot(self, guide_dist, n_steps=2000):
-        fig, ax = plt.subplots(figsize=(10, 6))
+        fig, ax = plt.subplots()
 
         # plot prior
         prior_sample = self.BayesianInferenceClass.show_prior(disp_plot=0)
@@ -696,10 +723,6 @@ class BayesianInferencePlot:
                 label=f"Posterior with $n={n}$",
             )
         ax.legend(loc="upper left")
-        ax.set_title(
-            f"SVI density of posterior distributions with {guide_dist} guide",
-            fontsize=15,
-        )
         plt.xlim(0, 1)
         plt.show()
 ```
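With the per-figure `figsize=(10, 6)` arguments and `ax.set_title` calls removed, sizing and titling fall to matplotlib defaults and the MyST figure captions. If a uniform size were still wanted, one hypothetical alternative (an assumption, not something this commit does) is a single default in the lecture's setup cell:

```python
import matplotlib.pyplot as plt

# Hypothetical setup-cell default: later plt.subplots() calls inherit
# this size, replacing the removed per-figure figsize arguments.
plt.rcParams["figure.figsize"] = (10, 6)
```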
@@ -732,6 +755,13 @@ For the same Beta prior, we shall
 Let's start with the analytical method that we described in {doc}`prob_meaning`.
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      Analytical Beta prior and posterior
+    name: fig_analytical
+---
 # first examine Beta prior
 BETA = BayesianInference(param=(5, 5), name_dist="beta")
 
@@ -741,7 +771,7 @@ BETA_plot = BayesianInferencePlot(true_theta, num_list, BETA)
 xaxis = np.linspace(0, 1, 1000)
 y_prior = st.beta.pdf(xaxis, 5, 5)
 
-fig, ax = plt.subplots(figsize=(10, 6))
+fig, ax = plt.subplots()
 # plot analytical beta prior
 ax.plot(xaxis, y_prior, label="Analytical Beta prior", color="#4C4E52")
 
@@ -758,7 +788,6 @@ for id, n in enumerate(N_list):
         label=f"Analytical Beta posterior with $n={n}$",
     )
 ax.legend(loc="upper left")
-ax.set_title("Analytical Beta prior and posterior", fontsize=15)
 plt.xlim(0, 1)
 plt.show()
 ```
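For reference, the analytical posteriors plotted in this cell follow from Beta-Binomial conjugacy: a $\textrm{Beta}(\alpha, \beta)$ prior with $k$ successes in $n$ Bernoulli draws yields

$$
\theta \mid Y \sim \textrm{Beta}(\alpha + k,\; \beta + n - k)
$$

with $\alpha = \beta = 5$ here; this is presumably the density `st.beta.pdf` evaluates for each $n$ in the loop over `N_list`.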
@@ -772,8 +801,8 @@ We'll do this for both MCMC and VI.
 mystnb:
   figure:
     caption: |
-      MCMC sampling density of posterior distributions
-    name: mcmc
+      MCMC density with Beta prior
+    name: fig_mcmc_beta
 ---
 
 BayesianInferencePlot(true_theta, num_list, BETA).MCMC_plot(
@@ -782,6 +811,14 @@ BayesianInferencePlot(true_theta, num_list, BETA).MCMC_plot(
 ```
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      SVI density with Beta guide
+    name: fig_svi_beta
+---
+
 BayesianInferencePlot(true_theta, num_list, BETA).SVI_plot(
     guide_dist="beta", n_steps=SVI_num_steps
 )
@@ -825,7 +862,7 @@ We first initialize the `BayesianInference` classes and then can directly call `
 STD_UNIFORM = BayesianInference(param=(0, 1), name_dist="uniform")
 UNIFORM = BayesianInference(param=(0.2, 0.7), name_dist="uniform")
 
-# Try truncated lognormal
+# Try truncated log normal
 LOGNORMAL = BayesianInference(param=(0, 2), name_dist="lognormal")
 
 # Try Von Mises
@@ -836,6 +873,13 @@ LAPLACE = BayesianInference(param=(0.5, 0.07), name_dist="laplace")
 ```
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      MCMC density with uniform prior
+    name: fig_mcmc_uniform
+---
 # Uniform
 example_CLASS = STD_UNIFORM
 print(
@@ -861,7 +905,14 @@ Consequently, the posterior cannot put positive probability above $\overline{\th
 Note that when the true data-generating $\theta$ is located at $0.8$, as it is here, the posterior concentrates on the upper bound of the prior's support, $0.7$, as $n$ gets large.
 
 ```{code-cell} ipython3
-# Log Normal
+---
+mystnb:
+  figure:
+    caption: |
+      MCMC density with log normal prior
+    name: fig_mcmc_lognormal
+---
+# log normal
 example_CLASS = LOGNORMAL
 print(
     f"=======INFO=======\nParameters: {example_CLASS.param}\nPrior Dist: {example_CLASS.name_dist}"
@@ -872,7 +923,14 @@ BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(
 ```
 
 ```{code-cell} ipython3
-# Von Mises
+---
+mystnb:
+  figure:
+    caption: |
+      MCMC density with von Mises prior
+    name: fig_mcmc_vonmises
+---
+# von Mises
 example_CLASS = VONMISES
 print(
     f"=======INFO=======\nParameters: {example_CLASS.param}\nPrior Dist: {example_CLASS.name_dist}"
@@ -884,6 +942,13 @@ BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(
 ```
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      MCMC density with Laplace prior
+    name: fig_mcmc_laplace
+---
 # Laplace
 example_CLASS = LAPLACE
 print(
@@ -894,15 +959,23 @@ BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(
 )
 ```
 
+### VI
+
 To get more accuracy we will now increase the number of steps for Variational Inference (VI).
 
 ```{code-cell} ipython3
 SVI_num_steps = 50000
 ```
-
-#### VI with a truncated Normal guide
+#### VI with a truncated normal guide
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      SVI density with uniform prior and normal guide
+    name: fig_svi_uniform_normal
+---
 # Uniform
 example_CLASS = BayesianInference(param=(0, 1), name_dist="uniform")
 print(
@@ -914,7 +987,14 @@ BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(
 ```
 
 ```{code-cell} ipython3
-# Lognormal
+---
+mystnb:
+  figure:
+    caption: |
+      SVI density with log normal prior and normal guide
+    name: fig_svi_lognormal_normal
+---
+# log normal
 example_CLASS = LOGNORMAL
 print(
     f"=======INFO=======\nParameters: {example_CLASS.param}\nPrior Dist: {example_CLASS.name_dist}"
@@ -925,6 +1005,13 @@ BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(
 ```
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      SVI density with Laplace prior and normal guide
+    name: fig_svi_laplace_normal
+---
 # Laplace
 example_CLASS = LAPLACE
 print(
@@ -938,7 +1025,14 @@ BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(
 #### Variational inference with a Beta guide distribution
 
 ```{code-cell} ipython3
-# Uniform
+---
+mystnb:
+  figure:
+    caption: |
+      SVI density with uniform prior and Beta guide
+    name: fig_svi_uniform_beta
+---
+# uniform
 example_CLASS = STD_UNIFORM
 print(
     f"=======INFO=======\nParameters: {example_CLASS.param}\nPrior Dist: {example_CLASS.name_dist}"
@@ -949,7 +1043,14 @@ BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(
 ```
 
 ```{code-cell} ipython3
-# log Normal
+---
+mystnb:
+  figure:
+    caption: |
+      SVI density with log normal prior and Beta guide
+    name: fig_svi_lognormal_beta
+---
+# log normal
 example_CLASS = LOGNORMAL
 print(
     f"=======INFO=======\nParameters: {example_CLASS.param}\nPrior Dist: {example_CLASS.name_dist}"
@@ -960,7 +1061,14 @@ BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(
 ```
 
 ```{code-cell} ipython3
-# Von Mises
+---
+mystnb:
+  figure:
+    caption: |
+      SVI density with von Mises prior and Beta guide
+    name: fig_svi_vonmises_beta
+---
+# von Mises
 example_CLASS = VONMISES
 print(
     f"=======INFO=======\nParameters: {example_CLASS.param}\nPrior Dist: {example_CLASS.name_dist}"
@@ -972,6 +1080,13 @@ BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(
 ```
 
 ```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: |
+      SVI density with Laplace prior and Beta guide
+    name: fig_svi_laplace_beta
+---
 # Laplace
 example_CLASS = LAPLACE
 print(
