The function below calculates the excess demand for given parameters
```{code-cell} ipython3
@jax.jit
def e(p, A, b, c):
    return jnp.exp(-A @ p) + c - b * jnp.sqrt(p)
```
Our default parameter values will be
$$
A = \begin{bmatrix}
    0.5 & 0.4 \\
    0.8 & 0.2
\end{bmatrix},
\qquad
b = \begin{bmatrix}
    1 \\
    1
\end{bmatrix},
\qquad
c = \begin{bmatrix}
    1 \\
    1
\end{bmatrix}
$$
```{code-cell} ipython3
A = jnp.array([[0.5, 0.4], [0.8, 0.2]])
b = jnp.ones(2)
c = jnp.ones(2)
```
At a price level of $p = (1, 0.5)$, the excess demand is
```{code-cell} ipython3
p = jnp.array([1, 0.5])
ex_demand = e(p, A, b, c)

print(
    f"The excess demand for good 0 is {ex_demand[0]:.3f} \n"
    f"The excess demand for good 1 is {ex_demand[1]:.3f}"
)
```
To increase computational efficiency, we vectorize the computation using [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), which is much faster than a Python loop.
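As a minimal sketch of this idea (the batch of price vectors and the helper names below are illustrative, not part of the original lecture code), we can vectorize `e` over its price argument only, holding the parameters fixed:

```{code-cell} ipython3
# Illustrative sketch: vectorize e over its first argument (the price vector),
# keeping A, b and c fixed across the batch
e_batched = jax.vmap(e, in_axes=(0, None, None, None))

# A hypothetical batch of 100 two-good price vectors
p_grid = jnp.linspace(0.5, 2.0, 100)
p_batch = jnp.column_stack([p_grid, p_grid])

# One vectorized call evaluates excess demand at every price vector in the batch
batch_demand = e_batched(p_batch, A, b, c)
print(batch_demand.shape)  # (100, 2)
```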
Now the solution is even more accurate (although, in this low-dimensional problem, the difference is small):

```{code-cell} ipython3
p = solution.x
e_p = jnp.max(jnp.abs(e(p, A, b, c)))
e_p.item()
```
This is a multivariate version of [](oneD-newton)
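For reference, the update rule is the standard multivariate Newton step (here $J_e$ denotes the Jacobian of $e$, a notation introduced only for this restatement):

$$
p_{n+1} = p_n - J_e(p_n)^{-1} e(p_n),
\qquad n = 0, 1, \ldots
$$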
The iteration starts from some initial guess of the price vector $p_0$.
Here, instead of coding the Jacobian by hand, we use the `jacobian()` function in the `jax` library to auto-differentiate and compute the Jacobian.
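As a quick illustration (this check is ours, not part of the original code), the auto-differentiated Jacobian of the two-good excess demand function can be evaluated directly at the price vector used above:

```{code-cell} ipython3
# Illustrative only: Jacobian of excess demand with respect to prices,
# with the parameters A, b, c held fixed, evaluated at p = (1, 0.5)
jac_e = jax.jacobian(lambda p: e(p, A, b, c))
print(jac_e(jnp.array([1.0, 0.5])))  # a 2 x 2 matrix
```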
With only slight modification, we can generalize [our previous attempt](first_newton_attempt) to multidimensional problems
```{code-cell} ipython3
def newton(f, x_0, tol=1e-5, max_iter=10):
    x = x_0
    f_jac = jax.jacobian(f)

    @jax.jit
    def q(x):
        return x - jnp.linalg.solve(f_jac(x), f(x))

    error = tol + 1
    n = 0
    while error > tol:
        n += 1
        if n > max_iter:
            raise Exception("Max iteration reached without convergence")
        y = q(x)
        if jnp.any(jnp.isnan(y)):
            raise Exception("Solution not found with NaN generated")
        error = jnp.linalg.norm(x - y)
        x = y
        print(f"iteration {n}, error = {error:.5f}")
    print("\n" + f"Result = {x} \n")
    return x
```
We find the algorithm terminates in 4 steps
```{code-cell} ipython3
p = newton(lambda p: e(p, A, b, c), init_p)
```
```{code-cell} ipython3
e_p = jnp.max(jnp.abs(e(p, A, b, c)))
e_p.item()
```
With the larger overhead, the speed is not better than the optimized `scipy` function.
Our next step is to investigate a large market with 3,000 goods.
The excess demand function is essentially the same, but now the matrix $A$ is $3000 \times 3000$ and the parameter vectors $b$ and $c$ are $3000 \times 1$.
```{code-cell} ipython3
dim = 3000

# Create a JAX random key
key = jax.random.PRNGKey(123)

# Create a random matrix A and normalize the columns to sum to one
A = jax.random.uniform(key, (dim, dim))
s = jnp.sum(A, axis=0)
A = A / s

# Set up b and c
b = jnp.ones(dim)
c = jnp.ones(dim)
```
Here's our initial condition
```{code-cell} ipython3
init_p = jnp.ones(dim)
```
```{code-cell} ipython3
p = newton(lambda p: e(p, A, b, c), init_p)
```
```{code-cell} ipython3
e_p = jnp.max(jnp.abs(e(p, A, b, c)))
e_p.item()
```
With the same tolerance, we compare the runtime and accuracy of Newton's method to SciPy's `root` function
```{code-cell} ipython3
%%time
solution = root(
    lambda p: e(p, A, b, c),
    init_p,
    jac=lambda p: jax.jacobian(e)(p, A, b, c),
    method="hybr",
    tol=1e-5,
)
```
```{code-cell} ipython3
p = solution.x
e_p = jnp.max(jnp.abs(e(p, A, b, c)))
e_p.item()
```
The result should converge to the [analytical solution](solved_k).
Let's first define the parameters for this problem