@@ -65,6 +65,8 @@ import jax
6565import jax.numpy as jnp
6666from typing import NamedTuple
6767import quantecon as qe
68+
69+ jax.config.update("jax_enable_x64", True)
6870```
6971
7072## The model
@@ -97,7 +99,7 @@ As before, we will be able to compare with the true solutions
9799We store primitives in a ` NamedTuple ` built for JAX and create a factory function to generate instances.
98100
99101``` {code-cell} python3
100- class OptimalGrowthModelJAX (NamedTuple):
102+ class OptimalGrowthModel (NamedTuple):
101103 α: float # production parameter
102104 β: float # discount factor
103105 μ: float # shock location parameter
@@ -118,7 +120,7 @@ def create_optgrowth_model(α=0.4,
118120 shock_size=250,
119121 c_grid_size=200,
120122 seed=0):
121- """Factory function to create an OptimalGrowthModelJAX instance."""
123+ """Factory function to create an OptimalGrowthModel instance."""
122124
123125 key = jax.random.PRNGKey(seed)
124126 y_grid = jnp.linspace(1e-5, grid_max, grid_size)
@@ -127,7 +129,7 @@ def create_optgrowth_model(α=0.4,
127129
128130 # Avoid endpoints 0 and 1 to keep feasibility and positivity.
129131 c_grid_frac = jnp.linspace(1e-6, 1.0 - 1e-6, c_grid_size)
130- return OptimalGrowthModelJAX (α=α, β=β, μ=μ, s=s, γ=γ,
132+ return OptimalGrowthModel (α=α, β=β, μ=μ, s=s, γ=γ,
131133 y_grid=y_grid, shocks=shocks,
132134 c_grid_frac=c_grid_frac)
133135```
0 commit comments