Skip to content

Commit 7816bd6

Browse files
jstacclaude
andcommitted
fix: resolve build errors in ifp_egm and os_numerical
ifp_egm.md fixes: - Add missing initialization in solve_model_numpy (c_vals, ae_vals = c_vals_init, ae_vals_init) - Fix update function parameter order in simulate_household (t, state instead of state, t) - Fix jax.random.fold_in argument order (key, t instead of t, key) - Add .astype(jnp.int32) to z_next_idx to fix dtype mismatch - Update jax.vmap to use in_axes instead of axes - Add missing arguments to sim_all_households call - Fix compute_asset_stationary argument order in all calls os_numerical.md fixes: - Simplify maximize function signature from (g, upper_bound, args) to (g, upper_bound) - Remove unused args parameter and tuple unpacking All changes tested by converting to Python with jupytext and running successfully. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent f9dd477 commit 7816bd6

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

lectures/ifp_egm.md

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ def solve_model_numpy(
450450
Solve the model using time iteration with EGM.
451451
452452
"""
453+
c_vals, ae_vals = c_vals_init, ae_vals_init
453454
i = 0
454455
error = tol + 1
455456
@@ -798,11 +799,11 @@ def simulate_household(
798799
σ = lambda a, z_idx: jnp.interp(a, ae_vals[:, z_idx], c_vals[:, z_idx])
799800
800801
# Simulate forward T periods
801-
def update(state, t):
802+
def update(t, state):
802803
a, z_idx = state
803804
# Draw next shock z' from Π[z, z']
804-
current_key = jax.random.fold_in(t, key)
805-
z_next_idx = jax.random.choice(current_key, n_z, p=Π[z_idx])
805+
current_key = jax.random.fold_in(key, t)
806+
z_next_idx = jax.random.choice(current_key, n_z, p=Π[z_idx]).astype(jnp.int32)
806807
z_next = z_grid[z_next_idx]
807808
# Update assets: a' = R * (a - c) + Y'
808809
a_next = R * (a - σ(a, z_idx)) + y(z_next)
@@ -848,9 +849,9 @@ def compute_asset_stationary(
848849
keys = jax.random.split(key, num_households)
849850
# Vectorize simulate_household in (key, a_0, z_idx_0)
850851
sim_all_households = jax.vmap(
851-
simulate_household, axes=(0, 0, 0, None, None, None, None, None)
852+
simulate_household, in_axes=(0, 0, 0, None, None, None, None)
852853
)
853-
assets = sim_all_households(keys, a_0_vector, z_idx_0_vector)
854+
assets = sim_all_households(keys, a_0_vector, z_idx_0_vector, c_vals, ae_vals, ifp, T)
854855
855856
return np.array(assets)
856857
```
@@ -860,10 +861,10 @@ Now we call the function, generate the asset distribution and histogram it:
860861
```{code-cell} ipython3
861862
ifp = create_ifp()
862863
R, β, γ, Π, z_grid, s = ifp
863-
ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
864-
c_vals_init = ae_vals_init
864+
ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
865+
c_vals_init = ae_vals_init
865866
c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init)
866-
assets = compute_asset_stationary(ifp, c_vals, ae_vals)
867+
assets = compute_asset_stationary(c_vals, ae_vals, ifp)
867868
868869
fig, ax = plt.subplots()
869870
ax.hist(assets, bins=20, alpha=0.5, density=True)
@@ -911,9 +912,9 @@ fig, ax = plt.subplots()
911912
for r_val in r_vals:
912913
ifp = create_ifp(r=r_val)
913914
R, β, γ, Π, z_grid, s = ifp
914-
ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
915-
c_vals_init = ae_vals_init
916-
c_vals, ae_vals = solve_model(ifp, c_vals_init)
915+
ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
916+
c_vals_init = ae_vals_init
917+
c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init)
917918
# Plot policy
918919
ax.plot(ae_vals[:, 0], c_vals[:, 0], label=f'$r = {r_val:.3f}$')
919920
# Start next round with last solution
@@ -980,7 +981,7 @@ for r in r_vals:
980981
ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
981982
c_vals_init = ae_vals_init
982983
c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init)
983-
assets = compute_asset_stationary(ifp, c_vals, ae_vals, num_households=10_000, T=500)
984+
assets = compute_asset_stationary(c_vals, ae_vals, ifp, num_households=10_000, T=500)
984985
mean = np.mean(assets)
985986
asset_mean.append(mean)
986987
print(f' Mean assets: {mean:.4f}')

lectures/os_numerical.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,16 @@ The `maximize` function below is a small helper function that converts a
159159
SciPy minimization routine into a maximization routine.
160160

161161
```{code-cell} python3
162-
def maximize(g, upper_bound, args):
162+
def maximize(g, upper_bound):
163163
"""
164164
Maximize the function g over the interval [0, upper_bound].
165165
166166
We use the fact that the maximizer of g on any interval is
167-
also the minimizer of -g. The tuple args collects any extra
168-
arguments to g.
167+
also the minimizer of -g.
169168
170169
"""
171170
172-
objective = lambda x: -g(x, *args)
171+
objective = lambda x: -g(x)
173172
bounds = (0, upper_bound)
174173
result = minimize_scalar(objective, bounds=bounds, method='bounded')
175174
maximizer, maximum = result.x, -result.fun

0 commit comments

Comments
 (0)