@@ -475,7 +475,7 @@ class BayesianInference:
475475
476476 def beta_guide(self, data):
477477 """
478- Defines the candidate parametrized variational distribution that we train to approximate posterior with Pyro/Numpyro
478+ Defines the candidate parametrized variational distribution that we train to approximate posterior with numpyro
479479 Here we use parameterized beta
480480 """
481481 alpha_q = numpyro.param("alpha_q", 10, constraint=nconstraints.positive)
@@ -485,7 +485,7 @@ class BayesianInference:
485485
486486 def truncnormal_guide(self, data):
487487 """
488- Defines the candidate parametrized variational distribution that we train to approximate posterior with Pyro/Numpyro
488+ Defines the candidate parametrized variational distribution that we train to approximate posterior with numpyro
489489 Here we use truncated normal on [0,1]
490490 """
491491 loc = numpyro.param("loc", 0.5, constraint=nconstraints.interval(0.0, 1.0))
@@ -495,7 +495,6 @@ class BayesianInference:
495495 def SVI_init(self, guide_dist, lr=0.0005):
496496 """
497497 Initiate SVI training mode with Adam optimizer
498- NOTE: truncnormal_guide can only be used with numpyro solver
499498 """
500499 adam_params = {"lr": lr}
501500
@@ -653,7 +652,7 @@ class BayesianInferencePlot:
653652 color=self.colorlist[id - 1],
654653 label=f"Posterior with $n={n}$",
655654 )
656- ax.legend()
655+ ax.legend(loc="upper left" )
657656 ax.set_title("MCMC Sampling density of Posterior Distributions", fontsize=15)
658657 plt.xlim(0, 1)
659658 plt.show()
@@ -709,7 +708,7 @@ class BayesianInferencePlot:
709708 color=self.colorlist[id - 1],
710709 label=f"Posterior with $n={n}$",
711710 )
712- ax.legend()
711+ ax.legend(loc="upper left" )
713712 ax.set_title(
714713 f"SVI density of Posterior Distributions with {guide_dist} guide",
715714 fontsize=15,
@@ -772,7 +771,7 @@ for id, n in enumerate(N_list):
772771 color=colorlist[id - 1],
773772 label=f"Analytical Beta Posterior with $n={n}$",
774773 )
775- ax.legend()
774+ ax.legend(loc="upper left" )
776775ax.set_title("Analytical Beta Prior and Posterior", fontsize=15)
777776plt.xlim(0, 1)
778777plt.show()
0 commit comments