lectures/ifp_advanced.md (291 additions, 1 deletion)
@@ -333,7 +333,7 @@ is just $\mathbb E R_t$.

We test the condition $\beta \mathbb E R_t < 1$ in the code below.

-## Implementation
+## Numba Implementation

We will assume that $R_t = \exp(a_r \zeta_t + b_r)$ where $a_r, b_r$
are constants and $\{ \zeta_t\}$ is IID standard normal.

@@ -583,6 +583,296 @@ The dashed line is the 45 degree line.
We can see from the figure that the dynamics will be stable --- assets do not
diverge even in the highest state.

## JAX Implementation

We now provide a JAX implementation of the model.

JAX is a high-performance numerical computing library that provides automatic differentiation and JIT compilation, with support for GPU/TPU acceleration.

First we need to import JAX and related libraries:

```{code-cell} ipython
import jax
import jax.numpy as jnp
from jax import vmap
from typing import NamedTuple

# Import jax.jit with a different name to avoid conflict with numba.jit
jax_jit = jax.jit
```

We enable 64-bit precision in JAX to ensure accurate results that match the Numba implementation:

```{code-cell} ipython
jax.config.update("jax_enable_x64", True)
```

Here's the JAX version of the IFP class; we use a `NamedTuple` because its instances are immutable and are handled as pytrees, which makes them compatible with JAX's JIT compilation:

```{code-cell} ipython
class IFP_JAX(NamedTuple):
    """
    A NamedTuple that stores primitives for the income fluctuation
    problem, using JAX.
    """
    γ: float
    β: float
    P: jnp.ndarray
    a_r: float
    b_r: float
    a_y: float
    b_y: float
    s_grid: jnp.ndarray
    η_draws: jnp.ndarray
    ζ_draws: jnp.ndarray


def create_ifp_jax(γ=1.5,
                   β=0.96,
                   P=np.array([(0.9, 0.1),
                               (0.1, 0.9)]),
                   a_r=0.1,
                   b_r=0.0,
                   a_y=0.2,
                   b_y=0.5,
                   shock_draw_size=50,
                   grid_max=10,
                   grid_size=100,
                   seed=1234):
    """
    Create an instance of IFP_JAX with the given parameters.
    """
    # Test stability assuming {R_t} is IID and adopts the lognormal
    # specification given below. The test is then β E R_t < 1.
```
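
Since $\{ \zeta_t \}$ is IID standard normal, the lognormal specification $R_t = \exp(a_r \zeta_t + b_r)$ gives $\mathbb E R_t = \exp(b_r + a_r^2/2)$, so the stability test reduces to $\beta \exp(b_r + a_r^2/2) < 1$. The next cell is a minimal standalone sketch of that check using the default parameter values above (an illustration only, not the constructor's own code):

```{code-cell} ipython
# Standalone sketch of the stability check β E R_t < 1 under the
# lognormal specification R_t = exp(a_r ζ_t + b_r) with ζ_t ~ N(0, 1).
# Illustration only -- uses the default parameter values from above.
β_test, a_r_test, b_r_test = 0.96, 0.1, 0.0
ER = jnp.exp(b_r_test + a_r_test**2 / 2)   # E R_t for lognormal R_t
print(β_test * ER)                         # should be strictly below 1
```
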
As we can see, the differences between the two implementations are extremely small (on the order of machine precision), confirming that both methods produce essentially identical results.
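
Here, "machine precision" refers to the spacing of 64-bit floating-point numbers, which is roughly $2.2 \times 10^{-16}$ and can be checked directly:

```{code-cell} ipython
# Machine epsilon for 64-bit floats -- the scale of the "machine precision"
# referred to above (roughly 2.2e-16).
print(jnp.finfo(jnp.float64).eps)
```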

The tiny differences arise from:

- Different random number generators (NumPy vs JAX)
- Minor differences in the order of floating-point operations
- Different interpolation implementations

Despite these minor numerical differences, both implementations converge to the same optimal policy.

The JAX implementation provides several advantages:

1. **GPU/TPU acceleration**: JAX can automatically utilize GPU/TPU hardware for faster computation
2. **Automatic differentiation**: JAX provides automatic differentiation, which can be useful for sensitivity analysis (see the sketch below)
3. **Functional programming**: JAX encourages a functional style that can be easier to reason about and parallelize
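
As a small illustration of point 2, the following standalone example uses `jax.grad` to differentiate a CRRA utility function and checks the result against the analytical marginal utility $u'(c) = c^{-\gamma}$ (it is separate from the solution code above):

```{code-cell} ipython
# Standalone illustration of automatic differentiation with jax.grad:
# differentiate CRRA utility u(c) = c**(1 - γ) / (1 - γ) and compare
# the result with the analytical marginal utility u'(c) = c**(-γ).
γ_test = 1.5

def u(c):
    return c**(1 - γ_test) / (1 - γ_test)

u_prime = jax.grad(u)

c = 2.0
print(u_prime(c), c**(-γ_test))   # the two values should agree
```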