Skip to content

Commit 438eda2

Browse files
authored
Apply suggestions from code review
pep8 compliance
1 parent 9650869 commit 438eda2

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

lectures/mccall_fitted_vfi.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def create_mccall_model(c=1,
227227
@jax.jit
228228
def update(model, v, d):
229229
"""Update value function and continuation value."""
230+
230231
# Unpack model parameters
231232
c, α, β, w_grid, w_draws = model
232233
u = jnp.log
@@ -377,12 +378,14 @@ def compute_res_wage_given_s(s, m=2.0, seed=1234):
377378
a, b = m - s, m + s
378379
key = jax.random.PRNGKey(seed)
379380
uniform_draws = jax.random.uniform(key, shape=(10_000,), minval=a, maxval=b)
381+
380382
# Create model with default parameters but replace wage draws
381383
model = create_mccall_model(w_draws=uniform_draws)
382384
w_bar = compute_reservation_wage(model)
383385
return w_bar
384386
385387
s_vals = jnp.linspace(1.0, 2.0, 15)
388+
386389
# Use vmap with different seeds for each s value
387390
seeds = jnp.arange(len(s_vals))
388391
compute_vectorized = jax.vmap(compute_res_wage_given_s, in_axes=(0, None, 0))

0 commit comments

Comments
 (0)