@@ -125,7 +125,7 @@ for μ in [1, 5, 10]:
125125 )
126126
127127ax.grid()
128- ax.set_xlabel("$y$", fontsize=14)
128+ ax.set_xlabel(r "$y$", fontsize=14)
129129ax.set_ylabel(r"$f(y \mid \mu)$", fontsize=14)
130130ax.axis(xmin=0, ymin=0)
131131ax.legend(fontsize=14)
@@ -284,8 +284,8 @@ def plot_joint_poisson(μ=7, y_n=20):
284284 ax = fig.add_subplot(111, projection="3d")
285285 ax.plot_surface(X, Y, Z.T, cmap="terrain", alpha=0.6)
286286 ax.scatter(X, Y, Z.T, color="black", alpha=0.5, linewidths=1)
287- ax.set(xlabel="$y_1$", ylabel="$y_2$")
288- ax.set_zlabel("$f(y_1, y_2)$", labelpad=10)
287+ ax.set(xlabel=r "$y_1$", ylabel=r "$y_2$")
288+ ax.set_zlabel(r "$f(y_1, y_2)$", labelpad=10)
289289 plt.show()
290290
291291
@@ -610,7 +610,7 @@ for β in [7, 8.5, 9.5, 10]:
610610 m, c = find_tangent(β)
611611 y = m * β_line + c
612612 ax.plot(β_line, y, "-", c="purple", alpha=0.8)
613- ax.text(β + 2.05, y[-1], f "$G({β}) = {abs(m):.0f}$", fontsize=12)
613+ ax.text(β + 2.05, y[-1], rf "$G({β}) = {abs(m):.0f}$", fontsize=12)
614614 ax.vlines(β, -24, logL(β), linestyles="--", alpha=0.5)
615615 ax.hlines(logL(β), 6, β, linestyles="--", alpha=0.5)
616616
@@ -646,7 +646,9 @@ X = jnp.array([[1, 2, 5], [1, 1, 3], [1, 4, 2], [1, 5, 2], [1, 3, 1]])
646646
647647y = jnp.array([1, 0, 1, 1, 0])
648648
649- stats_poisson = Poisson(y.__array__(), X.__array__()).fit()
649+ y_numpy = y.__array__()
650+ X_numpy = X.__array__()
651+ stats_poisson = Poisson(y_numpy, X_numpy).fit()
650652print(stats_poisson.summary())
651653```
652654
@@ -991,7 +993,9 @@ newton_raphson(prob, β)
991993``` {code-cell} ipython3
992994# Use statsmodels to verify results
993995# Note: use __array__() method to convert jax to numpy arrays
994- print(Probit(y.__array__(), X.__array__()).fit().summary())
996+ y_numpy = y.__array__()
997+ X_numpy = X.__array__()
998+ print(Probit(y_numpy, X_numpy).fit().summary())
995999```
9961000
9971001``` {solution-end}
0 commit comments