Skip to content

Commit 7d00832

Browse files
committed
update to use JAX entirely
1 parent 9bead7e commit 7d00832

File tree

1 file changed

+183
-112
lines changed

1 file changed

+183
-112
lines changed

lectures/kesten_processes.md

Lines changed: 183 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ jupytext:
44
extension: .md
55
format_name: myst
66
format_version: 0.13
7-
jupytext_version: 1.16.6
7+
jupytext_version: 1.17.2
88
kernelspec:
99
display_name: Python 3 (ipykernel)
1010
language: python
@@ -56,14 +56,53 @@ Let's start with some imports:
5656

5757
```{code-cell} ipython3
5858
import matplotlib.pyplot as plt
59-
import numpy as np
6059
import quantecon as qe
6160
import yfinance as yf
61+
import jax
62+
import jax.numpy as jnp
63+
from jax import random, vmap, jit
64+
from functools import partial
65+
from typing import NamedTuple
6266
```
6367

6468
Additional technical background related to this lecture can be found in the
6569
monograph by {cite}`buraczewski2016stochastic`.
6670

71+
We will use the following general-purpose function for generating time series paths
72+
73+
```{code-cell} ipython3
74+
:tags: [hide-input]
75+
76+
@partial(jax.jit, static_argnames=['f', 'num_steps'])
77+
def generate_path(f, initial_state, num_steps, model, key):
78+
"""
79+
Generate a time series by repeatedly applying an update rule.
80+
Given a map f, initial state x_0, and model parameters θ, this
81+
function computes and returns the sequence {x_t}_{t=0}^{T-1} when
82+
x_{t+1} = f(x_t, t, θ)
83+
Args:
84+
f: Update function mapping (x_t, t, model, key) -> x_{t+1}
85+
initial_state: Initial state x_0
86+
num_steps: Number of time steps T to simulate
87+
model: Model parameters
88+
key: Random key for reproducible randomness
89+
Returns:
90+
Array of shape (dim(x), T) containing the time series path
91+
[x_0, x_1, x_2, ..., x_{T-1}]
92+
"""
93+
def update_wrapper(carry, t):
94+
"""Wrapper function that adapts f for use with JAX scan."""
95+
state, subkey = carry
96+
subkey, new_subkey = random.split(subkey)
97+
next_state = f(state, t, model, new_subkey)
98+
return (next_state, subkey), state
99+
100+
# Initial carry: (initial_state, key)
101+
init_carry = (initial_state, key)
102+
_, path = jax.lax.scan(update_wrapper, init_carry, jnp.arange(num_steps))
103+
return path.T
104+
```
105+
67106
## Kesten processes
68107

69108
```{index} single: Kesten processes; heavy tails
@@ -327,26 +366,49 @@ This leads to spikes in the time series, which fill out the extreme right hand t
327366
The spikes in the time series are visible in the following simulation, which generates of 10 paths when $a_t$ and $b_t$ are lognormal.
328367

329368
```{code-cell} ipython3
330-
μ = -0.5
331-
σ = 1.0
369+
class KestenModel(NamedTuple):
370+
"""Parameters for Kesten process X_{t+1} = a_{t+1} X_t + η_{t+1}"""
371+
μ: float = -0.5 # location parameter for log(a_t)
372+
σ: float = 1.0 # scale parameter for log(a_t)
332373
333374
334-
def kesten_ts(ts_length=100):
335-
x = np.zeros(ts_length)
336-
for t in range(ts_length - 1):
337-
a = np.exp(μ + σ * np.random.randn())
338-
b = np.exp(np.random.randn())
339-
x[t+1] = a * x[t] + b
340-
return x
375+
@jax.jit
376+
def kesten_update(current_x, time_step, model, key):
377+
"""
378+
Update function for Kesten process: X_{t+1} = a_{t+1} X_t + η_{t+1}
379+
"""
380+
# Split key for random number generation
381+
key_a, key_η = random.split(key, 2)
341382
383+
# Generate random shocks
384+
shock_a = random.normal(key_a)
385+
shock_η = random.normal(key_η)
386+
387+
# Compute a_t and η_t
388+
a = jnp.exp(model.μ + model.σ * shock_a)
389+
η = jnp.exp(shock_η)
390+
391+
# Kesten process update
392+
next_x = a * current_x + η
393+
394+
return next_x
342395
343396
fig, ax = plt.subplots()
344397
345398
num_paths = 10
346-
np.random.seed(12)
399+
model = KestenModel()
347400
348401
for i in range(num_paths):
349-
ax.plot(kesten_ts())
402+
key = random.PRNGKey(i)
403+
404+
path = generate_path(
405+
kesten_update,
406+
initial_state=0.0,
407+
num_steps=100,
408+
model=model,
409+
key=key
410+
)
411+
ax.plot(path)
350412
351413
ax.set(xlabel="time", ylabel="$X_t$")
352414
plt.show()
@@ -446,31 +508,55 @@ While the time path differs, you should see bursts of high volatility.
446508
Here is one solution:
447509

448510
```{code-cell} ipython3
449-
α_0 = 1e-5
450-
α_1 = 0.1
451-
β = 0.9
511+
class GARCHModel(NamedTuple):
512+
"""Parameters for GARCH(1,1) volatility model"""
513+
α_0: float = 1e-5 # constant term
514+
α_1: float = 0.1 # coefficient on lagged squared shock
515+
β: float = 0.9 # coefficient on lagged volatility
452516
453517
years = 15
454518
days = years * 250
455519
520+
@jax.jit
521+
def garch_update(current_state, time_step, model, key):
522+
"""Update function for GARCH(1,1) volatility and returns"""
523+
σ2_current, r_previous = current_state
524+
525+
# Split key for random number generation
526+
key_xi, key_zeta = random.split(key, 2)
527+
528+
# Generate random shocks
529+
ξ = random.normal(key_xi)
530+
ζ = random.normal(key_zeta)
531+
532+
# Update volatility
533+
σ2_next = model.α_0 + σ2_current * (model.α_1 * ξ**2 + model.β)
456534
457-
def garch_ts(ts_length=days):
458-
σ2 = 0
459-
r = np.zeros(ts_length)
460-
for t in range(ts_length - 1):
461-
ξ = np.random.randn()
462-
σ2 = α_0 + σ2 * (α_1 * ξ**2 + β)
463-
r[t] = np.sqrt(σ2) * np.random.randn()
464-
return r
535+
# Generate return
536+
r_current = jnp.sqrt(σ2_current) * ζ
465537
538+
return jnp.array([σ2_next, r_current])
466539
467540
fig, ax = plt.subplots()
468541
469-
np.random.seed(12)
542+
key = random.PRNGKey(0)
543+
model = GARCHModel()
470544
471-
ax.plot(garch_ts(), alpha=0.7)
545+
# Initial state
546+
initial_state = jnp.array([0.0, 0.0])
472547
473-
ax.set(xlabel="time", ylabel="$\\sigma_t^2$")
548+
path = generate_path(
549+
garch_update,
550+
initial_state=initial_state,
551+
num_steps=days,
552+
model=model,
553+
key=key
554+
)
555+
556+
# Extract and plot returns
557+
ax.plot(path[1, :], alpha=0.7)
558+
559+
ax.set(xlabel="time", ylabel="returns")
474560
plt.show()
475561
```
476562

@@ -667,108 +753,93 @@ s_init = 1.0 # initial condition for each firm
667753
:class: dropdown
668754
```
669755

670-
Here's one solution.
671-
First we generate the observations:
672-
673-
```{code-cell} ipython3
674-
import jax
675-
import jax.numpy as jnp
676-
from jax import random, vmap, jit
677-
678-
679-
def generate_single_draw(key, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init):
680-
"""Generate a single draw using JAX's scan for the time loop."""
681-
682-
def step_fn(carry, t):
683-
s, subkey = carry
684-
subkey, new_subkey = random.split(subkey)
685-
686-
# Generate random normal samples
687-
rand_normal = random.normal(new_subkey)
688-
689-
# Conditional logic using jnp.where
690-
# If s < s_bar: new_s = exp(μ_e + σ_e * randn())
691-
# Else: new_s = a * s + b
692-
# where a = exp(μ_a + σ_a * randn()), b = exp(μ_b + σ_b * randn())
693-
694-
# For the else branch, we need two random numbers
695-
subkey, key1, key2 = random.split(subkey, 3)
696-
rand_a = random.normal(key1)
697-
rand_b = random.normal(key2)
756+
Here's one solution using the `generate_path` framework.
698757

699-
# Calculate both possible new values
700-
new_s_under_bar = jnp.exp(μ_e + σ_e * rand_normal)
758+
First, we define the firm productivity update function:
701759

702-
a = jnp.exp(μ_a + σ_a * rand_a)
703-
b = jnp.exp(μ_b + σ_b * rand_b)
704-
new_s_over_bar = a * s + b
705-
706-
# Choose based on condition
707-
new_s = jnp.where(s < s_bar, new_s_under_bar, new_s_over_bar)
708-
709-
return (new_s, subkey), new_s
710-
711-
# Initial state: (s_init, key)
712-
init_carry = (s_init, key)
713-
714-
# Run the scan
715-
final_carry, _ = jax.lax.scan(step_fn, init_carry, jnp.arange(T))
716-
717-
# Return final s value
718-
return final_carry[0]
760+
```{code-cell} ipython3
761+
@jax.jit
762+
def firm_product_update(current_product, time_step, model, key):
763+
"""
764+
Update firm productivity according to entry/exit dynamics.
719765
766+
If productivity is below threshold: firm exits and is replaced by new entrant
767+
If productivity is above threshold: productivity evolves as Kesten process
768+
"""
769+
# Split key for random number generation
770+
key_a, key_η, key_e = random.split(key, 3)
771+
772+
# Generate random shocks
773+
shock_a = random.normal(key_a)
774+
shock_η = random.normal(key_η)
775+
shock_e = random.normal(key_e)
776+
777+
# Calculate potential new productivity values
778+
# If firm exits (s_t < s_bar): replaced by new entrant
779+
product_entrant = jnp.exp(model.μ_e + model.σ_e * shock_e)
780+
781+
# If firm continues (s_t >= s_bar): Kesten process dynamics
782+
a = jnp.exp(model.μ_a + model.σ_a * shock_a)
783+
η = jnp.exp(model.μ_b + model.σ_b * shock_η)
784+
product_incumbent = a * current_product + η
785+
786+
# Apply entry/exit rule
787+
new_product = jnp.where(
788+
current_product < model.s_bar,
789+
product_entrant,
790+
product_incumbent
791+
)
720792
721-
generate_single_draw = jax.jit(generate_single_draw, static_argnums=(8,))
793+
return new_product
722794
```
723795

724-
```{code-cell} ipython3
725-
# Use vmap to vectorize over the first argument (key)
726-
in_axes = [None] * 10
727-
in_axes[0] = 0
796+
Now we define a model container for parameters
728797

729-
vectorized_single_draw = vmap(
730-
generate_single_draw,
731-
in_axes=in_axes,
732-
)
798+
```{code-cell} ipython3
799+
class FirmDynamicsModel(NamedTuple):
800+
"""Parameters for firm dynamics with entry/exit"""
801+
μ_a: float = -0.5 # location parameter for log(a_t)
802+
σ_a: float = 0.1 # scale parameter for log(a_t)
803+
μ_b: float = 0.0 # location parameter for log(η_t)
804+
σ_b: float = 0.5 # scale parameter for log(η_t)
805+
μ_e: float = 0.0 # location parameter for log(e_t)
806+
σ_e: float = 0.5 # scale parameter for log(e_t)
807+
s_bar: float = 1.0 # exit threshold
733808
```
734809

810+
Now we generate multiple firm trajectories in parallel
811+
735812
```{code-cell} ipython3
736-
@jit
737-
def generate_draws(
738-
seed=0,
739-
μ_a=-0.5,
740-
σ_a=0.1,
741-
μ_b=0.0,
742-
σ_b=0.5,
743-
μ_e=0.0,
744-
σ_e=0.5,
745-
s_bar=1.0,
746-
T=500,
747-
M=1_000_000,
748-
s_init=1.0,
749-
):
750-
"""
751-
JAX-jit version of the generate_draws function.
752-
Returns:
753-
Array of M draws
754-
"""
755-
# Create M different random keys for parallel execution
813+
def generate_firm_distribution(model,
814+
seed=0, M=1_000_000, T=500, s_init=1.0):
815+
"""Generate distribution of firm productivities after T periods."""
816+
817+
# Create random keys for each firm
756818
key = random.PRNGKey(seed)
757819
keys = random.split(key, M)
758820
759-
draws = vectorized_single_draw(
760-
keys, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init
761-
)
821+
@jax.jit
822+
def single_firm_path(firm_key):
823+
# Generate path and return final productivity
824+
path = generate_path(
825+
firm_product_update,
826+
initial_state=s_init,
827+
num_steps=T,
828+
model=model,
829+
key=firm_key
830+
)
831+
return path[-1]
762832
763-
return draws
764-
```
833+
# Apply to all firms in parallel
834+
product_dist = vmap(single_firm_path)(keys)
765835
766-
```{code-cell} ipython3
767-
# Generate the observations
768-
data = generate_draws()
836+
return product_dist
837+
838+
# Generate the data
839+
data = generate_firm_distribution(FirmDynamicsModel())
769840
```
770841

771-
Now we produce the rank-size plot:
842+
Let's produce the rank-size plot
772843

773844
```{code-cell} ipython3
774845
fig, ax = plt.subplots()

0 commit comments

Comments
 (0)