You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -50,14 +50,18 @@ We will use the following imports:
50
50
51
51
```{code-cell} ipython3
52
52
import matplotlib.pyplot as plt
53
-
import numpy as np
54
-
from numba import jit, float64
55
-
from numba.experimental import jitclass
53
+
import jax
54
+
import jax.numpy as jnp
55
+
from typing import NamedTuple
56
+
import quantecon as qe
57
+
58
+
# Set JAX to use CPU
59
+
jax.config.update('jax_platform_name', 'cpu')
56
60
```
57
61
58
-
## The Algorithm
62
+
## The algorithm
59
63
60
-
The model is the same as the McCall model with job separation we {doc}`studied before <mccall_model_with_separation>`, except that the wage offer distribution is continuous.
64
+
The model is the same as the McCall model with job separation that we {doc}`studied before <mccall_model_with_separation>`, except that the wage offer distribution is continuous.
61
65
62
66
We are going to start with the two Bellman equations we obtained for the model with job separation after {ref}`a simplifying transformation <ast_mcm>`.
63
67
@@ -82,16 +86,16 @@ v(w) = u(w) + \beta
82
86
83
87
The unknowns here are the function $v$ and the scalar $d$.
84
88
85
-
The difference between these and the pair of Bellman equations we previously worked on are
89
+
The differences between these and the pair of Bellman equations we previously worked on are
86
90
87
-
1.in {eq}`bell1mcmc`, what used to be a sum over a finite number of wage values is an integral over an infinite set.
91
+
1.In {eq}`bell1mcmc`, what used to be a sum over a finite number of wage values is an integral over an infinite set.
88
92
1. The function $v$ in {eq}`bell2mcmc` is defined over all $w \in \mathbb R_+$.
89
93
90
94
The function $q$ in {eq}`bell1mcmc` is the density of the wage offer distribution.
91
95
92
96
Its support is taken as equal to $\mathbb R_+$.
93
97
94
-
### Value Function Iteration
98
+
### Value function iteration
95
99
96
100
In theory, we should now proceed as follows:
97
101
@@ -106,12 +110,12 @@ The iterates of the value function can neither be calculated exactly nor stored
106
110
107
111
To see the issue, consider {eq}`bell2mcmc`.
108
112
109
-
Even if $v$ is a known function, the only way to store its update $v'$
113
+
Even if $v$ is a known function, the only way to store its update $v'$
110
114
is to record its value $v'(w)$ for every $w \in \mathbb R_+$.
111
115
112
116
Clearly, this is impossible.
113
117
114
-
### Fitted Value Function Iteration
118
+
### Fitted value function iteration
115
119
116
120
What we will do instead is use **fitted value function iteration**.
117
121
@@ -141,25 +145,25 @@ One good choice from both respects is continuous piecewise linear interpolation.
141
145
142
146
This method
143
147
144
-
1. combines well with value function iteration (see., e.g.,
148
+
1. combines well with value function iteration (see, e.g.,
145
149
{cite}`gordon1995stable` or {cite}`stachurski2008continuous`) and
146
150
1. preserves useful shape properties such as monotonicity and concavity/convexity.
147
151
148
-
Linear interpolation will be implemented using [numpy.interp](https://numpy.org/doc/stable/reference/generated/numpy.interp.html).
152
+
Linear interpolation will be implemented using JAX's interpolation function `jnp.interp`.
149
153
150
154
The next figure illustrates piecewise linear interpolation of an arbitrary
151
155
function on grid points $0, 0.2, 0.4, 0.6, 0.8, 1$.
152
156
153
157
```{code-cell} python3
154
158
def f(x):
155
-
y1 = 2 * np.cos(6 * x) + np.sin(14 * x)
159
+
y1 = 2 * jnp.cos(6 * x) + jnp.sin(14 * x)
156
160
return y1 + 2.5
157
161
158
-
c_grid = np.linspace(0, 1, 6)
159
-
f_grid = np.linspace(0, 1, 150)
162
+
c_grid = jnp.linspace(0, 1, 6)
163
+
f_grid = jnp.linspace(0, 1, 150)
160
164
161
165
def Af(x):
162
-
return np.interp(x, c_grid, f(c_grid))
166
+
return jnp.interp(x, c_grid, f(c_grid))
163
167
164
168
fig, ax = plt.subplots()
165
169
@@ -175,123 +179,128 @@ plt.show()
175
179
176
180
## Implementation
177
181
178
-
The first step is to build a jitted class for the McCall model with separation and
179
-
a continuous wage offer distribution.
182
+
The first step is to build a JAX-compatible structure for the McCall model with separation and a continuous wage offer distribution.
180
183
181
184
We will take the utility function to be the log function for this application, with $u(c) = \ln c$.
182
185
183
186
We will adopt the lognormal distribution for wages, with $w = \exp(\mu + \sigma z)$
184
187
when $z$ is standard normal and $\mu, \sigma$ are parameters.
0 commit comments