Commit c63f294

update lecture to use JAX
1 parent edb44f5 commit c63f294

1 file changed: +84 -75 lines changed

lectures/newton_method.md

Lines changed: 84 additions & 75 deletions
@@ -73,24 +73,20 @@ Then we apply Newton's method to multidimensional settings to solve
 market for equilibria with multiple goods.
 
 At the end of the lecture, we leverage the power of automatic
-differentiation in [`autograd`](https://github.com/HIPS/autograd) to solve a very high-dimensional equilibrium problem
+differentiation in [`jax`](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html) to solve a very high-dimensional equilibrium problem
 
-```{code-cell} ipython3
-:tags: [hide-output]
-
-!pip install autograd
-```
 
 We use the following imports in this lecture
 
 ```{code-cell} ipython3
 import matplotlib.pyplot as plt
 from typing import NamedTuple
 from scipy.optimize import root
-from autograd import jacobian
+import jax.numpy as jnp
+import jax
 
-# Thinly-wrapped numpy to enable automatic differentiation
-import autograd.numpy as np
+# Enable 64-bit precision
+jax.config.update("jax_enable_x64", True)
 ```
 
 ## Fixed point computation using Newton's method
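
Throughout the updated lecture, derivatives and Jacobians come from JAX's automatic differentiation rather than from `autograd`. As a minimal, self-contained sketch of the two calls the new code relies on, `jax.grad` and `jax.jacobian` (the functions `f` and `g` below are toy examples, not taken from the lecture):

```python
import jax
import jax.numpy as jnp

# Match the lecture's configuration above
jax.config.update("jax_enable_x64", True)

f = lambda x: x**3 - 2 * x        # toy scalar function
df = jax.grad(f)                  # df(x) computes f'(x)
print(df(2.0))                    # 3 * 2.0**2 - 2 = 10.0

g = lambda x: jnp.array([x[0]**2 + x[1], jnp.sin(x[1])])   # toy vector function
J = jax.jacobian(g)               # J(x) returns the 2x2 Jacobian matrix
print(J(jnp.array([1.0, 2.0])))
```
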
@@ -172,7 +168,7 @@ Here is a function to provide a 45 degree plot of the dynamics.
 def plot_45(params, ax, fontsize=14):
 
     k_min, k_max = 0.0, 3.0
-    k_grid = np.linspace(k_min, k_max, 1200)
+    k_grid = jnp.linspace(k_min, k_max, 1200)
 
     # Plot the functions
     lb = r"$g(k) = sAk^{\alpha} + (1 - \delta)k$"
@@ -353,12 +349,12 @@ def plot_trajectories(
     ax2.plot(ks4, "-o", label="newton steps")
 
     for ax in axes:
-        ax.plot(k_star * np.ones(n), "k--")
+        ax.plot(k_star * jnp.ones(n), "k--")
         ax.legend(fontsize=fs, frameon=False)
         ax.set_ylim(0.6, 3.2)
         ax.set_yticks((k_star,))
         ax.set_yticklabels(("$k^*$",), fontsize=fs)
-        ax.set_xticks(np.linspace(0, 19, 20))
+        ax.set_xticks(jnp.linspace(0, 19, 20))
 
     plt.show()
 ```
@@ -418,10 +414,12 @@ The following code implements the iteration [](oneD-newton)
 (first_newton_attempt)=
 
 ```{code-cell} ipython3
-def newton(f, Df, x_0, tol=1e-7, max_iter=100_000):
+def newton(f, x_0, tol=1e-7, max_iter=100_000):
     x = x_0
+    Df = jax.grad(f)
 
     # Implement the zero-finding formula
+    @jax.jit
     def q(x):
         return x - f(x) / Df(x)
 
@@ -432,10 +430,10 @@ def newton(f, Df, x_0, tol=1e-7, max_iter=100_000):
         if n > max_iter:
             raise Exception("Max iteration reached without convergence")
         y = q(x)
-        error = np.abs(x - y)
+        error = jnp.abs(x - y)
         x = y
         print(f"iteration {n}, error = {error:.5f}")
-    return x
+    return x.item()
 ```
 
 Numerous libraries implement Newton's method in one dimension, including
@@ -459,9 +457,7 @@ Let's apply this idea to the Solow problem
 
 ```{code-cell} ipython3
 params = create_solow_params()
-k_star_approx_newton = newton(
-    f=lambda x: g(x, params) - x, Df=lambda x: Dg(x, params) - 1, x_0=0.8
-)
+k_star_approx_newton = newton(f=lambda x: g(x, params) - x, x_0=0.8)
 ```
 
 ```{code-cell} ipython3
@@ -556,8 +552,9 @@ $$
 The function below calculates the excess demand for given parameters
 
 ```{code-cell} ipython3
+@jax.jit
 def e(p, A, b, c):
-    return np.exp(-A @ p) + c - b * np.sqrt(p)
+    return jnp.exp(-A @ p) + c - b * jnp.sqrt(p)
 ```
 
 Our default parameter values will be
@@ -581,36 +578,47 @@ A = \begin{bmatrix}
 $$
 
 ```{code-cell} ipython3
-A = np.array([[0.5, 0.4], [0.8, 0.2]])
-b = np.ones(2)
-c = np.ones(2)
+A = jnp.array([[0.5, 0.4], [0.8, 0.2]])
+b = jnp.ones(2)
+c = jnp.ones(2)
 ```
 
 At a price level of $p = (1, 0.5)$, the excess demand is
 
 ```{code-cell} ipython3
-ex_demand = e((1.0, 0.5), A, b, c)
+p = jnp.array([1, 0.5])
+ex_demand = e(p, A, b, c)
 
 print(
     f"The excess demand for good 0 is {ex_demand[0]:.3f} \n"
     f"The excess demand for good 1 is {ex_demand[1]:.3f}"
 )
 ```
 
+To speed up this computation, we vectorize the excess demand function with [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), which is much faster than looping over prices in Python.
+
+```{code-cell} ipython3
+# Vectorize e over an axis of price vectors
+e_vectorized_p_1 = jax.vmap(e, in_axes=(0, None, None, None))
+# Vectorize again so e can be applied to a 2-D grid of price pairs
+e_vectorized = jax.vmap(e_vectorized_p_1, in_axes=(0, None, None, None))
+```
+
 Next we plot the two functions $e_0$ and $e_1$ on a grid of $(p_0, p_1)$ values, using contour surfaces and lines.
 
 We will use the following function to build the contour plots
 
 ```{code-cell} ipython3
 def plot_excess_demand(ax, good=0, grid_size=100, grid_max=4, surface=True):
+    p_grid = jnp.linspace(0, grid_max, grid_size)
+    # Create meshgrid for all combinations of p_1 and p_2
+    P1, P2 = jnp.meshgrid(p_grid, p_grid, indexing="ij")
+    # Stack to create array of shape (grid_size, grid_size, 2)
+    P = jnp.stack([P1, P2], axis=-1)
 
-    # Create a 100x100 grid
-    p_grid = np.linspace(0, grid_max, grid_size)
-    z = np.empty((100, 100))
-
-    for i, p_1 in enumerate(p_grid):
-        for j, p_2 in enumerate(p_grid):
-            z[i, j] = e((p_1, p_2), A, b, c)[good]
+    # Compute all values at once using the vectorized function
+    z_full = e_vectorized(P, A, b, c)
+    z = z_full[:, :, good]
 
     if surface:
         cs1 = ax.contourf(p_grid, p_grid, z.T, alpha=0.5)
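
The hunk above replaces an explicit double loop over the price grid with two nested `jax.vmap` transforms. A small self-contained sketch of why the composition produces the same array as the loops (the function and shapes below are illustrative, not from the lecture):

```python
import jax
import jax.numpy as jnp

def f(p, A):                                   # stand-in for the excess demand function
    return jnp.exp(-A @ p)

A = jnp.array([[0.5, 0.4], [0.8, 0.2]])
grid = jnp.linspace(0.1, 1.0, 4)
P1, P2 = jnp.meshgrid(grid, grid, indexing="ij")
P = jnp.stack([P1, P2], axis=-1)               # shape (4, 4, 2)

# Inner vmap maps over one grid axis, outer vmap over the other
f_vec = jax.vmap(jax.vmap(f, in_axes=(0, None)), in_axes=(0, None))
z = f_vec(P, A)                                # shape (4, 4, 2)

# Same values as an explicit double loop
z_loop = jnp.stack([jnp.stack([f(P[i, j], A) for j in range(4)])
                    for i in range(4)])
print(jnp.allclose(z, z_loop))                 # True
```
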
@@ -662,7 +670,7 @@ To solve for $p^*$ more precisely, we use a zero-finding algorithm from `scipy.o
 We supply $p = (1, 1)$ as our initial guess.
 
 ```{code-cell} ipython3
-init_p = np.ones(2)
+init_p = jnp.ones(2)
 ```
 
 This uses the [modified Powell method](https://docs.scipy.org/doc/scipy/reference/optimize.root-hybr.html#optimize-root-hybr) to find the zero
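
The `root` call itself is unchanged by this commit, so it does not appear in the diff. As a rough sketch of how it is typically invoked with the objects defined above (the exact cell in the lecture may differ):

```python
# Find p such that e(p, A, b, c) = 0, starting from init_p
solution = root(lambda p: e(p, A, b, c), init_p, method="hybr")
p = solution.x
```
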
@@ -682,7 +690,7 @@ p
 This looks close to our guess from observing the figure. We can plug it back into $e$ to test that $e(p) \approx 0$:
 
 ```{code-cell} ipython3
-e_p = np.max(np.abs(e(p, A, b, c)))
+e_p = jnp.max(jnp.abs(e(p, A, b, c)))
 e_p.item()
 ```
 
@@ -708,12 +716,12 @@ def jacobian_e(p, A, b, c):
     p_0, p_1 = p
     a_00, a_01 = A[0, :]
     a_10, a_11 = A[1, :]
-    j_00 = -a_00 * np.exp(-a_00 * p_0) - (b[0] / 2) * p_0 ** (-1 / 2)
-    j_01 = -a_01 * np.exp(-a_01 * p_1)
-    j_10 = -a_10 * np.exp(-a_10 * p_0)
-    j_11 = -a_11 * np.exp(-a_11 * p_1) - (b[1] / 2) * p_1 ** (-1 / 2)
+    j_00 = -a_00 * jnp.exp(-a_00 * p_0) - (b[0] / 2) * p_0 ** (-1 / 2)
+    j_01 = -a_01 * jnp.exp(-a_01 * p_1)
+    j_10 = -a_10 * jnp.exp(-a_10 * p_0)
+    j_11 = -a_11 * jnp.exp(-a_11 * p_1) - (b[1] / 2) * p_1 ** (-1 / 2)
     J = [[j_00, j_01], [j_10, j_11]]
-    return np.array(J)
+    return jnp.array(J)
 ```
 
 ```{code-cell} ipython3
@@ -727,7 +735,7 @@ Now the solution is even more accurate (although, in this low-dimensional proble
 
 ```{code-cell} ipython3
 p = solution.x
-e_p = np.max(np.abs(e(p, A, b, c)))
+e_p = jnp.max(jnp.abs(e(p, A, b, c)))
 e_p.item()
 ```
 
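
The cell that re-runs the solver with the hand-coded Jacobian is also unchanged and therefore omitted from the diff; a plausible sketch, using `jacobian_e` from the hunk above (not verbatim from the lecture):

```python
# Supply the analytical Jacobian to the root finder
solution = root(lambda p: e(p, A, b, c), init_p,
                jac=lambda p: jacobian_e(p, A, b, c), method="hybr")
p = solution.x
```
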
@@ -747,35 +755,35 @@ This is a multivariate version of [](oneD-newton)
 
 The iteration starts from some initial guess of the price vector $p_0$.
 
-Here, instead of coding Jacobian by hand, we use the `jacobian()` function in the `autograd` library to auto-differentiate and calculate the Jacobian.
+Here, instead of coding the Jacobian by hand, we use `jax.jacobian()` to auto-differentiate and calculate the Jacobian.
 
 With only slight modification, we can generalize [our previous attempt](first_newton_attempt) to multidimensional problems
 
 ```{code-cell} ipython3
 def newton(f, x_0, tol=1e-5, max_iter=10):
     x = x_0
-    q = lambda x: x - np.linalg.solve(jacobian(f)(x), f(x))
+    f_jac = jax.jacobian(f)
+
+    @jax.jit
+    def q(x):
+        return x - jnp.linalg.solve(f_jac(x), f(x))
+
     error = tol + 1
     n = 0
     while error > tol:
         n += 1
         if n > max_iter:
             raise Exception("Max iteration reached without convergence")
         y = q(x)
-        if any(np.isnan(y)):
+        if jnp.any(jnp.isnan(y)):
             raise Exception("Solution not found with NaN generated")
-        error = np.linalg.norm(x - y)
+        error = jnp.linalg.norm(x - y)
         x = y
         print(f"iteration {n}, error = {error:.5f}")
     print("\n" + f"Result = {x} \n")
     return x
 ```
 
-```{code-cell} ipython3
-def e(p, A, b, c):
-    return np.exp(-np.dot(A, p)) + c - b * np.sqrt(p)
-```
-
 We find the algorithm terminates in 4 steps
 
 ```{code-cell} ipython3
@@ -784,7 +792,7 @@ p = newton(lambda p: e(p, A, b, c), init_p)
 ```
 
 ```{code-cell} ipython3
-e_p = np.max(np.abs(e(p, A, b, c)))
+e_p = jnp.max(jnp.abs(e(p, A, b, c)))
 e_p.item()
 ```
 
@@ -797,30 +805,29 @@ With the larger overhead, the speed is not better than the optimized `scipy` fun
 
 Our next step is to investigate a large market with 3,000 goods.
 
-A JAX version of this section using GPU accelerated linear algebra and
-automatic differentiation is available [here](https://jax.quantecon.org/newtons_method.html#application)
 
 The excess demand function is essentially the same, but now the matrix $A$ is $3000 \times 3000$ and the parameter vectors $b$ and $c$ are $3000 \times 1$.
 
 ```{code-cell} ipython3
 dim = 3000
-np.random.seed(123)
 
-# Create a random matrix A and normalize the rows to sum to one
-A = np.random.rand(dim, dim)
-A = np.asarray(A)
-s = np.sum(A, axis=0)
+# Create JAX random key
+key = jax.random.PRNGKey(123)
+
+# Create a random matrix A and normalize the columns to sum to one
+A = jax.random.uniform(key, (dim, dim))
+s = jnp.sum(A, axis=0)
 A = A / s
 
 # Set up b and c
-b = np.ones(dim)
-c = np.ones(dim)
+b = jnp.ones(dim)
+c = jnp.ones(dim)
 ```
 
 Here's our initial condition
 
 ```{code-cell} ipython3
-init_p = np.ones(dim)
+init_p = jnp.ones(dim)
 ```
 
 ```{code-cell} ipython3
@@ -829,24 +836,26 @@ p = newton(lambda p: e(p, A, b, c), init_p)
 ```
 
 ```{code-cell} ipython3
-e_p = np.max(np.abs(e(p, A, b, c)))
+e_p = jnp.max(jnp.abs(e(p, A, b, c)))
 e_p.item()
 ```
 
 With the same tolerance, we compare the runtime and accuracy of Newton's method to SciPy's `root` function
 
 ```{code-cell} ipython3
 %%time
-solution = root(lambda p: e(p, A, b, c),
-                init_p,
-                jac=lambda p: jacobian(e)(p, A, b, c),
-                method='hybr',
-                tol=1e-5)
+solution = root(
+    lambda p: e(p, A, b, c),
+    init_p,
+    jac=lambda p: jax.jacobian(e)(p, A, b, c),
+    method="hybr",
+    tol=1e-5,
+)
 ```
 
 ```{code-cell} ipython3
 p = solution.x
-e_p = np.max(np.abs(e(p, A, b, c)))
+e_p = jnp.max(jnp.abs(e(p, A, b, c)))
 e_p.item()
 ```
 
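
One caveat when reading the timings in this section: JAX dispatches work asynchronously, so a measurement is fairest when the result is forced to completion before the timer stops, for example with `jax.block_until_ready`:

```python
# Force the computation to finish inside the timed statement
%time p = jax.block_until_ready(newton(lambda p: e(p, A, b, c), init_p))
```
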
@@ -923,20 +932,21 @@ The result should converge to the [analytical solution](solved_k).
 Let's first define the parameters for this problem
 
 ```{code-cell} ipython3
-A = np.array([[2.0, 3.0, 3.0], [2.0, 4.0, 2.0], [1.0, 5.0, 1.0]])
+A = jnp.array([[2.0, 3.0, 3.0], [2.0, 4.0, 2.0], [1.0, 5.0, 1.0]])
 
 s = 0.2
 α = 0.5
 δ = 0.8
 
-initLs = [np.ones(3), np.array([3.0, 5.0, 5.0]), np.repeat(50.0, 3)]
+initLs = [jnp.ones(3), jnp.array([3.0, 5.0, 5.0]), jnp.repeat(50.0, 3)]
 ```
 
 Then define the multivariate version of the formula for the [law of motion of capital](motion_law)
 
 ```{code-cell} ipython3
+@jax.jit
 def multivariate_solow(k, A=A, s=s, α=α, δ=δ):
-    return s * np.dot(A, k**α) + (1 - δ) * k
+    return s * jnp.dot(A, k**α) + (1 - δ) * k
 ```
 
 Let's run through each starting value and see the output
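
The loop that feeds each starting value to the solver is unchanged by this commit and therefore not shown in the diff; it presumably looks something like the following sketch (reconstructed, not verbatim from the lecture):

```python
attempt = 1
for init in initLs:
    print(f"\nAttempt {attempt}: Starting value is {init} \n")
    %time k = newton(lambda k: multivariate_solow(k) - k, init)
    attempt += 1
```
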
@@ -966,15 +976,15 @@ Note the error is very small.
 We can also test our results on the known solution
 
 ```{code-cell} ipython3
-A = np.array([[2.0, 0.0, 0.0],
+A = jnp.array([[2.0, 0.0, 0.0],
               [0.0, 2.0, 0.0],
               [0.0, 0.0, 2.0]])
 
 s = 0.3
 α = 0.3
 δ = 0.4
 
-init = np.repeat(1.0, 3)
+init = jnp.repeat(1.0, 3)
 
 
 %time k = newton(lambda k: multivariate_solow(k, A=A, s=s, α=α, δ=δ) - k, \
@@ -1045,12 +1055,11 @@ Set the tolerance to $1e-15$ for more accurate output.
 Define parameters and initial values
 
 ```{code-cell} ipython3
-A = np.array([[0.2, 0.1, 0.7], [0.3, 0.2, 0.5], [0.1, 0.8, 0.1]])
-
-b = np.array([1.0, 1.0, 1.0])
-c = np.array([1.0, 1.0, 1.0])
+A = jnp.array([[0.2, 0.1, 0.7], [0.3, 0.2, 0.5], [0.1, 0.8, 0.1]])
+b = jnp.array([1.0, 1.0, 1.0])
+c = jnp.array([1.0, 1.0, 1.0])
 
-initLs = [np.repeat(5.0, 3), np.ones(3), np.array([4.5, 0.1, 4.0])]
+initLs = [jnp.repeat(5.0, 3), jnp.ones(3), jnp.array([4.5, 0.1, 4.0])]
 ```
 
 Let's run through each initial guess and check the output
