@@ -4,7 +4,7 @@ jupytext:
44 extension : .md
55 format_name : myst
66 format_version : 0.13
7- jupytext_version : 1.16.6
7+ jupytext_version : 1.17.2
88kernelspec :
99 display_name : Python 3 (ipykernel)
1010 language : python
@@ -56,14 +56,53 @@ Let's start with some imports:
5656
5757``` {code-cell} ipython3
5858import matplotlib.pyplot as plt
59- import numpy as np
6059import quantecon as qe
6160import yfinance as yf
61+ import jax
62+ import jax.numpy as jnp
63+ from jax import random, vmap, jit
64+ from functools import partial
65+ from typing import NamedTuple
6266```
6367
6468Additional technical background related to this lecture can be found in the
6569monograph by {cite}` buraczewski2016stochastic ` .
6670
71+ We will use the following general-purpose function for generating time series paths
72+
73+ ``` {code-cell} ipython3
74+ :tags: [hide-input]
75+
76+ @partial(jax.jit, static_argnames=['f', 'num_steps'])
77+ def generate_path(f, initial_state, num_steps, model, key):
78+ """
79+ Generate a time series by repeatedly applying an update rule.
80+ Given a map f, initial state x_0, and model parameters θ, this
81+ function computes and returns the sequence {x_t}_{t=0}^{T-1} when
82+ x_{t+1} = f(x_t, t, θ)
83+ Args:
84+ f: Update function mapping (x_t, t, model, key) -> x_{t+1}
85+ initial_state: Initial state x_0
86+ num_steps: Number of time steps T to simulate
87+ model: Model parameters
88+ key: Random key for reproducible randomness
89+ Returns:
90+ Array of shape (dim(x), T) containing the time series path
91+ [x_0, x_1, x_2, ..., x_{T-1}]
92+ """
93+ def update_wrapper(carry, t):
94+ """Wrapper function that adapts f for use with JAX scan."""
95+ state, subkey = carry
96+ subkey, new_subkey = random.split(subkey)
97+ next_state = f(state, t, model, new_subkey)
98+ return (next_state, subkey), state
99+
100+ # Initial carry: (initial_state, key)
101+ init_carry = (initial_state, key)
102+ _, path = jax.lax.scan(update_wrapper, init_carry, jnp.arange(num_steps))
103+ return path.T
104+ ```
105+
67106## Kesten processes
68107
69108``` {index} single: Kesten processes; heavy tails
@@ -327,26 +366,49 @@ This leads to spikes in the time series, which fill out the extreme right hand t
327366The spikes in the time series are visible in the following simulation, which generates of 10 paths when $a_t$ and $b_t$ are lognormal.
328367
329368``` {code-cell} ipython3
330- μ = -0.5
331- σ = 1.0
369+ class KestenModel(NamedTuple):
370+ """Parameters for Kesten process X_{t+1} = a_{t+1} X_t + η_{t+1}"""
371+ μ: float = -0.5 # location parameter for log(a_t)
372+ σ: float = 1.0 # scale parameter for log(a_t)
332373
333374
334- def kesten_ts(ts_length=100):
335- x = np.zeros(ts_length)
336- for t in range(ts_length - 1):
337- a = np.exp(μ + σ * np.random.randn())
338- b = np.exp(np.random.randn())
339- x[t+1] = a * x[t] + b
340- return x
375+ @jax.jit
376+ def kesten_update(current_x, time_step, model, key):
377+ """
378+ Update function for Kesten process: X_{t+1} = a_{t+1} X_t + η_{t+1}
379+ """
380+ # Split key for random number generation
381+ key_a, key_η = random.split(key, 2)
341382
383+ # Generate random shocks
384+ shock_a = random.normal(key_a)
385+ shock_η = random.normal(key_η)
386+
387+ # Compute a_t and η_t
388+ a = jnp.exp(model.μ + model.σ * shock_a)
389+ η = jnp.exp(shock_η)
390+
391+ # Kesten process update
392+ next_x = a * current_x + η
393+
394+ return next_x
342395
343396fig, ax = plt.subplots()
344397
345398num_paths = 10
346- np.random.seed(12 )
399+ model = KestenModel( )
347400
348401for i in range(num_paths):
349- ax.plot(kesten_ts())
402+ key = random.PRNGKey(i)
403+
404+ path = generate_path(
405+ kesten_update,
406+ initial_state=0.0,
407+ num_steps=100,
408+ model=model,
409+ key=key
410+ )
411+ ax.plot(path)
350412
351413ax.set(xlabel="time", ylabel="$X_t$")
352414plt.show()
@@ -446,31 +508,55 @@ While the time path differs, you should see bursts of high volatility.
446508Here is one solution:
447509
448510``` {code-cell} ipython3
449- α_0 = 1e-5
450- α_1 = 0.1
451- β = 0.9
511+ class GARCHModel(NamedTuple):
512+ """Parameters for GARCH(1,1) volatility model"""
513+ α_0: float = 1e-5 # constant term
514+ α_1: float = 0.1 # coefficient on lagged squared shock
515+ β: float = 0.9 # coefficient on lagged volatility
452516
453517years = 15
454518days = years * 250
455519
520+ @jax.jit
521+ def garch_update(current_state, time_step, model, key):
522+ """Update function for GARCH(1,1) volatility and returns"""
523+ σ2_current, r_previous = current_state
524+
525+ # Split key for random number generation
526+ key_xi, key_zeta = random.split(key, 2)
527+
528+ # Generate random shocks
529+ ξ = random.normal(key_xi)
530+ ζ = random.normal(key_zeta)
531+
532+ # Update volatility
533+ σ2_next = model.α_0 + σ2_current * (model.α_1 * ξ**2 + model.β)
456534
457- def garch_ts(ts_length=days):
458- σ2 = 0
459- r = np.zeros(ts_length)
460- for t in range(ts_length - 1):
461- ξ = np.random.randn()
462- σ2 = α_0 + σ2 * (α_1 * ξ**2 + β)
463- r[t] = np.sqrt(σ2) * np.random.randn()
464- return r
535+ # Generate return
536+ r_current = jnp.sqrt(σ2_current) * ζ
465537
538+ return jnp.array([σ2_next, r_current])
466539
467540fig, ax = plt.subplots()
468541
469- np.random.seed(12)
542+ key = random.PRNGKey(0)
543+ model = GARCHModel()
470544
471- ax.plot(garch_ts(), alpha=0.7)
545+ # Initial state
546+ initial_state = jnp.array([0.0, 0.0])
472547
473- ax.set(xlabel="time", ylabel="$\\sigma_t^2$")
548+ path = generate_path(
549+ garch_update,
550+ initial_state=initial_state,
551+ num_steps=days,
552+ model=model,
553+ key=key
554+ )
555+
556+ # Extract and plot returns
557+ ax.plot(path[1, :], alpha=0.7)
558+
559+ ax.set(xlabel="time", ylabel="returns")
474560plt.show()
475561```
476562
@@ -667,108 +753,93 @@ s_init = 1.0 # initial condition for each firm
667753:class: dropdown
668754```
669755
670- Here's one solution.
671- First we generate the observations:
672-
673- ``` {code-cell} ipython3
674- import jax
675- import jax.numpy as jnp
676- from jax import random, vmap, jit
677-
678-
679- def generate_single_draw(key, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init):
680- """Generate a single draw using JAX's scan for the time loop."""
681-
682- def step_fn(carry, t):
683- s, subkey = carry
684- subkey, new_subkey = random.split(subkey)
685-
686- # Generate random normal samples
687- rand_normal = random.normal(new_subkey)
688-
689- # Conditional logic using jnp.where
690- # If s < s_bar: new_s = exp(μ_e + σ_e * randn())
691- # Else: new_s = a * s + b
692- # where a = exp(μ_a + σ_a * randn()), b = exp(μ_b + σ_b * randn())
693-
694- # For the else branch, we need two random numbers
695- subkey, key1, key2 = random.split(subkey, 3)
696- rand_a = random.normal(key1)
697- rand_b = random.normal(key2)
756+ Here's one solution using the ` generate_path ` framework.
698757
699- # Calculate both possible new values
700- new_s_under_bar = jnp.exp(μ_e + σ_e * rand_normal)
758+ First, we define the firm productivity update function:
701759
702- a = jnp.exp(μ_a + σ_a * rand_a)
703- b = jnp.exp(μ_b + σ_b * rand_b)
704- new_s_over_bar = a * s + b
705-
706- # Choose based on condition
707- new_s = jnp.where(s < s_bar, new_s_under_bar, new_s_over_bar)
708-
709- return (new_s, subkey), new_s
710-
711- # Initial state: (s_init, key)
712- init_carry = (s_init, key)
713-
714- # Run the scan
715- final_carry, _ = jax.lax.scan(step_fn, init_carry, jnp.arange(T))
716-
717- # Return final s value
718- return final_carry[0]
760+ ``` {code-cell} ipython3
761+ @jax.jit
762+ def firm_product_update(current_product, time_step, model, key):
763+ """
764+ Update firm productivity according to entry/exit dynamics.
719765
766+ If productivity is below threshold: firm exits and is replaced by new entrant
767+ If productivity is above threshold: productivity evolves as Kesten process
768+ """
769+ # Split key for random number generation
770+ key_a, key_η, key_e = random.split(key, 3)
771+
772+ # Generate random shocks
773+ shock_a = random.normal(key_a)
774+ shock_η = random.normal(key_η)
775+ shock_e = random.normal(key_e)
776+
777+ # Calculate potential new productivity values
778+ # If firm exits (s_t < s_bar): replaced by new entrant
779+ product_entrant = jnp.exp(model.μ_e + model.σ_e * shock_e)
780+
781+ # If firm continues (s_t >= s_bar): Kesten process dynamics
782+ a = jnp.exp(model.μ_a + model.σ_a * shock_a)
783+ η = jnp.exp(model.μ_b + model.σ_b * shock_η)
784+ product_incumbent = a * current_product + η
785+
786+ # Apply entry/exit rule
787+ new_product = jnp.where(
788+ current_product < model.s_bar,
789+ product_entrant,
790+ product_incumbent
791+ )
720792
721- generate_single_draw = jax.jit(generate_single_draw, static_argnums=(8,))
793+ return new_product
722794```
723795
724- ``` {code-cell} ipython3
725- # Use vmap to vectorize over the first argument (key)
726- in_axes = [None] * 10
727- in_axes[0] = 0
796+ Now we define a model container for parameters
728797
729- vectorized_single_draw = vmap(
730- generate_single_draw,
731- in_axes=in_axes,
732- )
798+ ``` {code-cell} ipython3
799+ class FirmDynamicsModel(NamedTuple):
800+ """Parameters for firm dynamics with entry/exit"""
801+ μ_a: float = -0.5 # location parameter for log(a_t)
802+ σ_a: float = 0.1 # scale parameter for log(a_t)
803+ μ_b: float = 0.0 # location parameter for log(η_t)
804+ σ_b: float = 0.5 # scale parameter for log(η_t)
805+ μ_e: float = 0.0 # location parameter for log(e_t)
806+ σ_e: float = 0.5 # scale parameter for log(e_t)
807+ s_bar: float = 1.0 # exit threshold
733808```
734809
810+ Now we generate multiple firm trajectories in parallel
811+
735812``` {code-cell} ipython3
736- @jit
737- def generate_draws(
738- seed=0,
739- μ_a=-0.5,
740- σ_a=0.1,
741- μ_b=0.0,
742- σ_b=0.5,
743- μ_e=0.0,
744- σ_e=0.5,
745- s_bar=1.0,
746- T=500,
747- M=1_000_000,
748- s_init=1.0,
749- ):
750- """
751- JAX-jit version of the generate_draws function.
752- Returns:
753- Array of M draws
754- """
755- # Create M different random keys for parallel execution
813+ def generate_firm_distribution(model,
814+ seed=0, M=1_000_000, T=500, s_init=1.0):
815+ """Generate distribution of firm productivities after T periods."""
816+
817+ # Create random keys for each firm
756818 key = random.PRNGKey(seed)
757819 keys = random.split(key, M)
758820
759- draws = vectorized_single_draw(
760- keys, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init
761- )
821+ @jax.jit
822+ def single_firm_path(firm_key):
823+ # Generate path and return final productivity
824+ path = generate_path(
825+ firm_product_update,
826+ initial_state=s_init,
827+ num_steps=T,
828+ model=model,
829+ key=firm_key
830+ )
831+ return path[-1]
762832
763- return draws
764- ```
833+ # Apply to all firms in parallel
834+ product_dist = vmap(single_firm_path)(keys)
765835
766- ``` {code-cell} ipython3
767- # Generate the observations
768- data = generate_draws()
836+ return product_dist
837+
838+ # Generate the data
839+ data = generate_firm_distribution(FirmDynamicsModel())
769840```
770841
771- Now we produce the rank-size plot:
842+ Let's produce the rank-size plot
772843
773844``` {code-cell} ipython3
774845fig, ax = plt.subplots()
0 commit comments