Skip to content

Commit a4f7341

Browse files
committed
add 64 bit to see run time
1 parent ff7e42b commit a4f7341

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

lectures/optgrowth_fast.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ import jax
6565
import jax.numpy as jnp
6666
from typing import NamedTuple
6767
import quantecon as qe
68+
69+
jax.config.update("jax_enable_x64", True)
6870
```
6971

7072
## The model
@@ -97,7 +99,7 @@ As before, we will be able to compare with the true solutions
9799
We store primitives in a `NamedTuple` built for JAX and create a factory function to generate instances.
98100

99101
```{code-cell} python3
100-
class OptimalGrowthModelJAX(NamedTuple):
102+
class OptimalGrowthModel(NamedTuple):
101103
α: float # production parameter
102104
β: float # discount factor
103105
μ: float # shock location parameter
@@ -118,7 +120,7 @@ def create_optgrowth_model(α=0.4,
118120
shock_size=250,
119121
c_grid_size=200,
120122
seed=0):
121-
"""Factory function to create an OptimalGrowthModelJAX instance."""
123+
"""Factory function to create an OptimalGrowthModel instance."""
122124
123125
key = jax.random.PRNGKey(seed)
124126
y_grid = jnp.linspace(1e-5, grid_max, grid_size)
@@ -127,7 +129,7 @@ def create_optgrowth_model(α=0.4,
127129
128130
# Avoid endpoints 0 and 1 to keep feasibility and positivity.
129131
c_grid_frac = jnp.linspace(1e-6, 1.0 - 1e-6, c_grid_size)
130-
return OptimalGrowthModelJAX(α=α, β=β, μ=μ, s=s, γ=γ,
132+
return OptimalGrowthModel(α=α, β=β, μ=μ, s=s, γ=γ,
131133
y_grid=y_grid, shocks=shocks,
132134
c_grid_frac=c_grid_frac)
133135
```

0 commit comments

Comments
 (0)