Skip to content

Commit 85edd95

Browse files
Copilotmmcky
andcommitted
Complete JAX conversion and fix implementation issues
Co-authored-by: mmcky <[email protected]>
1 parent 734fb3b commit 85edd95

File tree

1 file changed

+24
-36
lines changed

1 file changed

+24
-36
lines changed

lectures/wealth_dynamics.md

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -361,32 +361,9 @@ def update_states(wdy, w, z, key):
361361
return wp, zp, key
362362
```
363363

364-
We will use a general function for generating time series in an efficient JAX-compatible manner.
364+
We will use a specialized function to generate time series in an efficient JAX-compatible manner.
365365

366366
```{code-cell} ipython3
367-
@partial(jax.jit, static_argnames=['n'])
368-
def generate_path(f, initial_state, n, key, **kwargs):
369-
"""
370-
Generate a time series by repeatedly applying an update rule.
371-
372-
Args:
373-
f: Update function with signature (state, t, key, **kwargs) -> (new_state, new_key)
374-
initial_state: Initial state
375-
n: Number of time steps to simulate
376-
key: Initial JAX random key
377-
**kwargs: Extra arguments passed to f
378-
379-
Returns:
380-
Array of shape (dim(state), n) containing the time series path
381-
"""
382-
def update_wrapper(carry, t):
383-
state, key = carry
384-
new_state, new_key = f(state, t, key, **kwargs)
385-
return (new_state, new_key), state
386-
387-
_, path = jax.lax.scan(update_wrapper, (initial_state, key), jnp.arange(n))
388-
return path
389-
390367
def wealth_time_series_step(state, t, key, wdy):
391368
"""
392369
Single time step for wealth time series simulation.
@@ -404,7 +381,7 @@ def wealth_time_series_step(state, t, key, wdy):
404381
wp, zp, new_key = update_states(wdy, w, z, key)
405382
return ((wp, zp), new_key)
406383
407-
@jax.jit
384+
@partial(jax.jit, static_argnames=['n'])
408385
def wealth_time_series(wdy, w_0, n, key):
409386
"""
410387
Generate a single time series of length n for wealth given
@@ -424,9 +401,13 @@ def wealth_time_series(wdy, w_0, n, key):
424401
"""
425402
key, subkey = jax.random.split(key)
426403
z_0 = wdy.z_mean + jnp.sqrt(wdy.z_var) * jax.random.normal(subkey)
427-
initial_state = (w_0, z_0)
428404
429-
path = generate_path(wealth_time_series_step, initial_state, n, key, wdy=wdy)
405+
def update_wrapper(carry, t):
406+
state, key = carry
407+
new_state, new_key = wealth_time_series_step(state, t, key, wdy)
408+
return (new_state, new_key), state
409+
410+
_, path = jax.lax.scan(update_wrapper, ((w_0, z_0), key), jnp.arange(n))
430411
return path[0] # Return only wealth component
431412
```
432413

@@ -435,7 +416,7 @@ Now here's function to simulate a cross section of households forward in time.
435416
Note the use of JAX vectorization to speed up computation.
436417

437418
```{code-cell} ipython3
438-
@jax.jit
419+
@partial(jax.jit, static_argnames=['shift_length'])
439420
def update_cross_section(wdy, w_distribution, shift_length=500, key=None):
440421
"""
441422
Shifts a cross-section of households forward in time using JAX vectorization.
@@ -463,7 +444,7 @@ def update_cross_section(wdy, w_distribution, shift_length=500, key=None):
463444
z_init = (wdy.z_mean +
464445
jnp.sqrt(wdy.z_var) * jax.random.normal(subkey, (num_households,)))
465446
466-
# Create initial state array
447+
# Create initial state array [wealth, z]
467448
initial_states = jnp.column_stack([w_distribution, z_init])
468449
469450
def update_household(carry, t):
@@ -473,9 +454,12 @@ def update_cross_section(wdy, w_distribution, shift_length=500, key=None):
473454
subkeys = jnp.array(subkeys)
474455
475456
# Vectorized update for all households
476-
new_states = jax.vmap(lambda state, k: update_states(wdy, state[0], state[1], k)[:2])(
477-
states, subkeys)
478-
new_states = jnp.array(new_states)
457+
def single_household_update(state, k):
458+
w, z = state
459+
wp, zp, _ = update_states(wdy, w, z, k) # Ignore returned key
460+
return jnp.array([wp, zp])
461+
462+
new_states = jax.vmap(single_household_update)(states, subkeys)
479463
480464
return (new_states, key), None
481465
@@ -538,7 +522,9 @@ def generate_lorenz_and_gini(wdy, num_households=100_000, T=500, key=None):
538522
ψ_0 = jnp.full(num_households, wdy.y_mean)
539523
540524
ψ_star = update_cross_section(wdy, ψ_0, shift_length=T, key=key)
541-
return qe.gini_coefficient(ψ_star), qe.lorenz_curve(ψ_star)
525+
# Convert JAX array to numpy for quantecon functions
526+
ψ_star_np = np.array(ψ_star)
527+
return qe.gini_coefficient(ψ_star_np), qe.lorenz_curve(ψ_star_np)
542528
```
543529

544530
Now we investigate how the Lorenz curves associated with the wealth distribution change as return to savings varies.
@@ -549,7 +535,7 @@ If you are running this yourself, note that it will take one or two minutes to e
549535

550536
This is unavoidable because we are executing a CPU intensive task.
551537

552-
In fact the code, which is JIT compiled and parallelized, runs extremely fast relative to the number of computations.
538+
In fact the code, which is JIT compiled by JAX and vectorized, runs extremely fast relative to the number of computations.
553539

554540
```{code-cell} ipython3
555541
%%time
@@ -575,7 +561,7 @@ We will look at this again via the Gini coefficient immediately below, but
575561
first consider the following image of our system resources when the code above
576562
is executing:
577563

578-
Since the code is both efficiently JIT compiled and fully parallelized, it's
564+
Since the code is both efficiently JIT compiled by JAX and fully vectorized, it's
579565
close to impossible to make this sequence of tasks run faster without changing
580566
hardware.
581567

@@ -729,7 +715,9 @@ Now let's see the rank-size plot:
729715
```{code-cell} ipython3
730716
fig, ax = plt.subplots()
731717
732-
rank_data, size_data = qe.rank_size(ψ_star, c=0.001)
718+
# Convert JAX array to numpy for quantecon functions
719+
ψ_star_np = np.array(ψ_star)
720+
rank_data, size_data = qe.rank_size(ψ_star_np, c=0.001)
733721
ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)
734722
ax.set_xlabel("log rank")
735723
ax.set_ylabel("log size")

0 commit comments

Comments
 (0)