Skip to content

Commit dab2635

Browse files
committed
updates
1 parent 7efaca6 commit dab2635

File tree

1 file changed

+59
-36
lines changed

1 file changed

+59
-36
lines changed

lectures/lake_model.md

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,8 @@ def simulate_markov_chain(P, T, init_state, key):
481481
def step(carry, key):
482482
state = carry
483483
probs = P[state]
484-
state_new = jax.random.choice(key, a=jnp.arange(len(probs)), p=probs)
484+
state_new = jax.random.choice(key,
485+
a=jnp.arange(len(probs)), p=probs)
485486
return state_new, state_new
486487
487488
keys = jax.random.split(key, T)
@@ -649,7 +650,8 @@ class McCallModel(NamedTuple):
649650
p_vec: jnp.ndarray # Probabilities over w_vec
650651
651652
652-
def create_mccall_model(α=0.2, β=0.98, γ=0.7, c=6.0, σ=2.0, w_vec=None, p_vec=None):
653+
def create_mccall_model(α=0.2, β=0.98, γ=0.7, c=6.0, σ=2.0,
654+
w_vec=None, p_vec=None):
653655
"""
654656
Create a McCallModel with default wage distribution if not provided.
655657
"""
@@ -741,36 +743,55 @@ Now let's compute and plot welfare, employment, unemployment, and tax revenue as
741743
function of the unemployment compensation rate
742744

743745
```{code-cell} ipython3
744-
# Some global variables that will stay constant
745-
α = 0.013
746-
α_q = (1-(1-α)**3) # Quarterly (α is monthly)
747-
b = 0.0124
748-
d = 0.00822
749-
β = 0.98
750-
γ = 1.0
751-
σ = 2.0
752-
753-
log_wage_mean, wage_grid_size, max_wage = 20, 200, 170
754-
w_vec_temp = jnp.linspace(1e-8, max_wage, wage_grid_size + 1)
755-
cdf = stats.norm.cdf(jnp.log(w_vec_temp),
756-
loc=jnp.log(log_wage_mean), scale=1)
757-
pdf = cdf[1:] - cdf[:-1]
758-
p_vec = pdf / pdf.sum()
759-
w_vec = (w_vec_temp[1:] + w_vec_temp[:-1]) / 2
746+
class EconomyParameters(NamedTuple):
747+
"""Parameters for the economy"""
748+
α: float
749+
α_q: float # Quarterly (α is monthly)
750+
b: float
751+
d: float
752+
β: float
753+
γ: float
754+
σ: float
755+
log_wage_mean: float
756+
wage_grid_size: int
757+
max_wage: float
758+
759+
def create_economy_params(α=0.013, b=0.0124, d=0.00822,
760+
β=0.98, γ=1.0, σ=2.0,
761+
log_wage_mean=20,
762+
wage_grid_size=200,
763+
max_wage=170):
764+
"""Create economy parameters with default values"""
765+
α_q = (1-(1-α)**3) # Convert monthly to quarterly
766+
return EconomyParameters(α=α, α_q=α_q, b=b, d=d, β=β, γ=γ, σ=σ,
767+
log_wage_mean=log_wage_mean,
768+
wage_grid_size=wage_grid_size,
769+
max_wage=max_wage)
770+
771+
def create_wage_distribution(params):
772+
"""Create wage distribution from parameters"""
773+
w_vec_temp = jnp.linspace(1e-8, params.max_wage,
774+
params.wage_grid_size + 1)
775+
cdf = stats.norm.cdf(jnp.log(w_vec_temp),
776+
loc=jnp.log(params.log_wage_mean), scale=1)
777+
pdf = cdf[1:] - cdf[:-1]
778+
p_vec = pdf / pdf.sum()
779+
w_vec = (w_vec_temp[1:] + w_vec_temp[:-1]) / 2
780+
return w_vec, p_vec
760781
761782
762783
@jax.jit
763-
def compute_optimal_quantities(c, τ):
784+
def compute_optimal_quantities(c, τ, params, w_vec, p_vec):
764785
"""
765786
Compute the reservation wage, job finding rate and value functions
766787
of the workers given c and τ.
767788
"""
768789
mcm = create_mccall_model(
769-
α=α_q,
770-
β=β,
771-
γ=γ,
790+
α=params.α_q,
791+
β=params.β,
792+
γ=params.γ,
772793
c=c-τ, # Post tax compensation
773-
σ=σ,
794+
σ=params.σ,
774795
w_vec=w_vec-τ, # Post tax wages
775796
p_vec=p_vec
776797
)
@@ -779,21 +800,22 @@ def compute_optimal_quantities(c, τ):
779800
w_idx = jnp.searchsorted(V - U, 0)
780801
w_bar = jnp.where(w_idx == len(V), jnp.inf, mcm.w_vec[w_idx])
781802
782-
λ = γ * jnp.sum(p_vec * (w_vec - τ > w_bar))
803+
λ = params.γ * jnp.sum(p_vec * (w_vec - τ > w_bar))
783804
return w_bar, λ, V, U
784805
785806
786807
@jax.jit
787-
def compute_steady_state_quantities(c, τ):
808+
def compute_steady_state_quantities(c, τ, params, w_vec, p_vec):
788809
"""
789810
Compute the steady state unemployment rate given c and τ using optimal
790811
quantities from the McCall model and computing corresponding steady
791812
state quantities
792813
"""
793-
w_bar, λ, V, U = compute_optimal_quantities(c, τ)
814+
w_bar, λ, V, U = compute_optimal_quantities(c, τ,
815+
params, w_vec, p_vec)
794816
795817
# Compute steady state employment and unemployment rates
796-
lm = create_lake_model(α=α_q, λ=λ, b=b, d=d)
818+
lm = create_lake_model(α=params.α_q, λ=λ, b=params.b, d=params.d)
797819
x = rate_steady_state(lm)
798820
u, e = x
799821
@@ -805,12 +827,13 @@ def compute_steady_state_quantities(c, τ):
805827
return e, u, welfare
806828
807829
808-
def find_balanced_budget_tax(c):
830+
def find_balanced_budget_tax(c, params, w_vec, p_vec):
809831
"""
810832
Find the tax level that will induce a balanced budget using bisection.
811833
"""
812834
def steady_state_budget(t):
813-
e, u, w = compute_steady_state_quantities(c, t)
835+
e, u, w = compute_steady_state_quantities(c, t,
836+
params, w_vec, p_vec)
814837
return t - u * c
815838
816839
# Use a simple bisection method
@@ -832,6 +855,10 @@ def find_balanced_budget_tax(c):
832855
return t_mid
833856
834857
858+
# Create economy parameters and wage distribution
859+
params = create_economy_params()
860+
w_vec, p_vec = create_wage_distribution(params)
861+
835862
# Levels of unemployment insurance we wish to study
836863
c_vec = jnp.linspace(5, 140, 60)
837864
@@ -841,8 +868,9 @@ empl_vec = []
841868
welfare_vec = []
842869
843870
for c in c_vec:
844-
t = find_balanced_budget_tax(c)
845-
e_rate, u_rate, welfare = compute_steady_state_quantities(c, t)
871+
t = find_balanced_budget_tax(c, params, w_vec, p_vec)
872+
e_rate, u_rate, welfare = compute_steady_state_quantities(c, t, params,
873+
w_vec, p_vec)
846874
tax_vec.append(t)
847875
unempl_vec.append(u_rate)
848876
empl_vec.append(e_rate)
@@ -875,7 +903,6 @@ In the JAX implementation of the Lake Model, we use a `NamedTuple` for parameter
875903
This approach has several advantages:
876904
1. It's immutable, which aligns with JAX's functional programming paradigm
877905
2. Functions can be JIT-compiled for better performance
878-
3. It's easier to use with JAX's automatic differentiation
879906
880907
In this exercise, your task is to:
881908
1. Update parameters by creating a new instance of the model with the desired parameters (`α=0.02, λ=0.3`).
@@ -951,10 +978,6 @@ How long does the economy take to converge to its new steady state?
951978

952979
What is the new steady state level of employment?
953980

954-
```{note}
955-
It may be easier to use the class created in exercise 1 to help with changing variables.
956-
```
957-
958981
```{exercise-end}
959982
```
960983

0 commit comments

Comments
 (0)