@@ -481,7 +481,8 @@ def simulate_markov_chain(P, T, init_state, key):
481481 def step(carry, key):
482482 state = carry
483483 probs = P[state]
484- state_new = jax.random.choice(key, a=jnp.arange(len(probs)), p=probs)
484+ state_new = jax.random.choice(key,
485+ a=jnp.arange(len(probs)), p=probs)
485486 return state_new, state_new
486487
487488 keys = jax.random.split(key, T)
@@ -649,7 +650,8 @@ class McCallModel(NamedTuple):
649650 p_vec: jnp.ndarray # Probabilities over w_vec
650651
651652
652- def create_mccall_model(α=0.2, β=0.98, γ=0.7, c=6.0, σ=2.0, w_vec=None, p_vec=None):
653+ def create_mccall_model(α=0.2, β=0.98, γ=0.7, c=6.0, σ=2.0,
654+ w_vec=None, p_vec=None):
653655 """
654656 Create a McCallModel with default wage distribution if not provided.
655657 """
@@ -741,36 +743,55 @@ Now let's compute and plot welfare, employment, unemployment, and tax revenue as
741743function of the unemployment compensation rate
742744
743745``` {code-cell} ipython3
744- # Some global variables that will stay constant
745- α = 0.013
746- α_q = (1-(1-α)**3) # Quarterly (α is monthly)
747- b = 0.0124
748- d = 0.00822
749- β = 0.98
750- γ = 1.0
751- σ = 2.0
752-
753- log_wage_mean, wage_grid_size, max_wage = 20, 200, 170
754- w_vec_temp = jnp.linspace(1e-8, max_wage, wage_grid_size + 1)
755- cdf = stats.norm.cdf(jnp.log(w_vec_temp),
756- loc=jnp.log(log_wage_mean), scale=1)
757- pdf = cdf[1:] - cdf[:-1]
758- p_vec = pdf / pdf.sum()
759- w_vec = (w_vec_temp[1:] + w_vec_temp[:-1]) / 2
746+ class EconomyParameters(NamedTuple):
747+ """Parameters for the economy"""
748+ α: float
749+ α_q: float # Quarterly (α is monthly)
750+ b: float
751+ d: float
752+ β: float
753+ γ: float
754+ σ: float
755+ log_wage_mean: float
756+ wage_grid_size: int
757+ max_wage: float
758+
759+ def create_economy_params(α=0.013, b=0.0124, d=0.00822,
760+ β=0.98, γ=1.0, σ=2.0,
761+ log_wage_mean=20,
762+ wage_grid_size=200,
763+ max_wage=170):
764+ """Create economy parameters with default values"""
765+ α_q = (1-(1-α)**3) # Convert monthly to quarterly
766+ return EconomyParameters(α=α, α_q=α_q, b=b, d=d, β=β, γ=γ, σ=σ,
767+ log_wage_mean=log_wage_mean,
768+ wage_grid_size=wage_grid_size,
769+ max_wage=max_wage)
770+
771+ def create_wage_distribution(params):
772+ """Create wage distribution from parameters"""
773+ w_vec_temp = jnp.linspace(1e-8, params.max_wage,
774+ params.wage_grid_size + 1)
775+ cdf = stats.norm.cdf(jnp.log(w_vec_temp),
776+ loc=jnp.log(params.log_wage_mean), scale=1)
777+ pdf = cdf[1:] - cdf[:-1]
778+ p_vec = pdf / pdf.sum()
779+ w_vec = (w_vec_temp[1:] + w_vec_temp[:-1]) / 2
780+ return w_vec, p_vec
760781
761782
762783@jax.jit
763- def compute_optimal_quantities(c, τ):
784+ def compute_optimal_quantities(c, τ, params, w_vec, p_vec ):
764785 """
765786 Compute the reservation wage, job finding rate and value functions
766787 of the workers given c and τ.
767788 """
768789 mcm = create_mccall_model(
769- α=α_q,
770- β=β,
771- γ=γ,
790+ α=params. α_q,
791+ β=params. β,
792+ γ=params. γ,
772793 c=c-τ, # Post tax compensation
773- σ=σ,
794+ σ=params. σ,
774795 w_vec=w_vec-τ, # Post tax wages
775796 p_vec=p_vec
776797 )
@@ -779,21 +800,22 @@ def compute_optimal_quantities(c, τ):
779800 w_idx = jnp.searchsorted(V - U, 0)
780801 w_bar = jnp.where(w_idx == len(V), jnp.inf, mcm.w_vec[w_idx])
781802
782- λ = γ * jnp.sum(p_vec * (w_vec - τ > w_bar))
803+ λ = params. γ * jnp.sum(p_vec * (w_vec - τ > w_bar))
783804 return w_bar, λ, V, U
784805
785806
786807@jax.jit
787- def compute_steady_state_quantities(c, τ):
808+ def compute_steady_state_quantities(c, τ, params, w_vec, p_vec ):
788809 """
789810 Compute the steady state unemployment rate given c and τ using optimal
790811 quantities from the McCall model and computing corresponding steady
791812 state quantities
792813 """
793- w_bar, λ, V, U = compute_optimal_quantities(c, τ)
814+ w_bar, λ, V, U = compute_optimal_quantities(c, τ,
815+ params, w_vec, p_vec)
794816
795817 # Compute steady state employment and unemployment rates
796- lm = create_lake_model(α=α_q, λ=λ, b=b, d=d)
818+ lm = create_lake_model(α=params. α_q, λ=λ, b=params. b, d=params. d)
797819 x = rate_steady_state(lm)
798820 u, e = x
799821
@@ -805,12 +827,13 @@ def compute_steady_state_quantities(c, τ):
805827 return e, u, welfare
806828
807829
808- def find_balanced_budget_tax(c):
830+ def find_balanced_budget_tax(c, params, w_vec, p_vec ):
809831 """
810832 Find the tax level that will induce a balanced budget using bisection.
811833 """
812834 def steady_state_budget(t):
813- e, u, w = compute_steady_state_quantities(c, t)
835+ e, u, w = compute_steady_state_quantities(c, t,
836+ params, w_vec, p_vec)
814837 return t - u * c
815838
816839 # Use a simple bisection method
@@ -832,6 +855,10 @@ def find_balanced_budget_tax(c):
832855 return t_mid
833856
834857
858+ # Create economy parameters and wage distribution
859+ params = create_economy_params()
860+ w_vec, p_vec = create_wage_distribution(params)
861+
835862# Levels of unemployment insurance we wish to study
836863c_vec = jnp.linspace(5, 140, 60)
837864
@@ -841,8 +868,9 @@ empl_vec = []
841868welfare_vec = []
842869
843870for c in c_vec:
844- t = find_balanced_budget_tax(c)
845- e_rate, u_rate, welfare = compute_steady_state_quantities(c, t)
871+ t = find_balanced_budget_tax(c, params, w_vec, p_vec)
872+ e_rate, u_rate, welfare = compute_steady_state_quantities(c, t, params,
873+ w_vec, p_vec)
846874 tax_vec.append(t)
847875 unempl_vec.append(u_rate)
848876 empl_vec.append(e_rate)
@@ -875,7 +903,6 @@ In the JAX implementation of the Lake Model, we use a `NamedTuple` for parameter
875903This approach has several advantages:
8769041. It's immutable, which aligns with JAX's functional programming paradigm
8779052. Functions can be JIT-compiled for better performance
878- 3. It's easier to use with JAX's automatic differentiation
879906
880907In this exercise, your task is to:
8819081. Update parameters by creating a new instance of the model with the desired parameters (`α=0.02, λ=0.3`).
@@ -951,10 +978,6 @@ How long does the economy take to converge to its new steady state?
951978
952979What is the new steady state level of employment?
953980
954- ``` {note}
955- It may be easier to use the class created in exercise 1 to help with changing variables.
956- ```
957-
958981``` {exercise-end}
959982```
960983
0 commit comments