@@ -361,32 +361,9 @@ def update_states(wdy, w, z, key):
361361 return wp, zp, key
362362```
363363
364- We will use a general function for generating time series in an efficient JAX-compatible manner.
364+ We will use a specialized function to generate time series in an efficient JAX-compatible manner.
365365
366366``` {code-cell} ipython3
367- @partial(jax.jit, static_argnames=['n'])
368- def generate_path(f, initial_state, n, key, **kwargs):
369- """
370- Generate a time series by repeatedly applying an update rule.
371-
372- Args:
373- f: Update function with signature (state, t, key, **kwargs) -> (new_state, new_key)
374- initial_state: Initial state
375- n: Number of time steps to simulate
376- key: Initial JAX random key
377- **kwargs: Extra arguments passed to f
378-
379- Returns:
380- Array of shape (dim(state), n) containing the time series path
381- """
382- def update_wrapper(carry, t):
383- state, key = carry
384- new_state, new_key = f(state, t, key, **kwargs)
385- return (new_state, new_key), state
386-
387- _, path = jax.lax.scan(update_wrapper, (initial_state, key), jnp.arange(n))
388- return path
389-
390367def wealth_time_series_step(state, t, key, wdy):
391368 """
392369 Single time step for wealth time series simulation.
@@ -404,7 +381,7 @@ def wealth_time_series_step(state, t, key, wdy):
404381 wp, zp, new_key = update_states(wdy, w, z, key)
405382 return ((wp, zp), new_key)
406383
407- @jax.jit
384+ @partial( jax.jit, static_argnames=['n'])
408385def wealth_time_series(wdy, w_0, n, key):
409386 """
410387 Generate a single time series of length n for wealth given
@@ -424,9 +401,13 @@ def wealth_time_series(wdy, w_0, n, key):
424401 """
425402 key, subkey = jax.random.split(key)
426403 z_0 = wdy.z_mean + jnp.sqrt(wdy.z_var) * jax.random.normal(subkey)
427- initial_state = (w_0, z_0)
428404
429- path = generate_path(wealth_time_series_step, initial_state, n, key, wdy=wdy)
405+ def update_wrapper(carry, t):
406+ state, key = carry
407+ new_state, new_key = wealth_time_series_step(state, t, key, wdy)
408+ return (new_state, new_key), state
409+
410+ _, path = jax.lax.scan(update_wrapper, ((w_0, z_0), key), jnp.arange(n))
430411 return path[0] # Return only wealth component
431412```
432413
@@ -435,7 +416,7 @@ Now here's function to simulate a cross section of households forward in time.
435416Note the use of JAX vectorization to speed up computation.
436417
437418``` {code-cell} ipython3
438- @jax.jit
419+ @partial( jax.jit, static_argnames=['shift_length'])
439420def update_cross_section(wdy, w_distribution, shift_length=500, key=None):
440421 """
441422 Shifts a cross-section of households forward in time using JAX vectorization.
@@ -463,7 +444,7 @@ def update_cross_section(wdy, w_distribution, shift_length=500, key=None):
463444 z_init = (wdy.z_mean +
464445 jnp.sqrt(wdy.z_var) * jax.random.normal(subkey, (num_households,)))
465446
466- # Create initial state array
447+ # Create initial state array [wealth, z]
467448 initial_states = jnp.column_stack([w_distribution, z_init])
468449
469450 def update_household(carry, t):
@@ -473,9 +454,12 @@ def update_cross_section(wdy, w_distribution, shift_length=500, key=None):
473454 subkeys = jnp.array(subkeys)
474455
475456 # Vectorized update for all households
476- new_states = jax.vmap(lambda state, k: update_states(wdy, state[0], state[1], k)[:2])(
477- states, subkeys)
478- new_states = jnp.array(new_states)
457+ def single_household_update(state, k):
458+ w, z = state
459+ wp, zp, _ = update_states(wdy, w, z, k) # Ignore returned key
460+ return jnp.array([wp, zp])
461+
462+ new_states = jax.vmap(single_household_update)(states, subkeys)
479463
480464 return (new_states, key), None
481465
@@ -538,7 +522,9 @@ def generate_lorenz_and_gini(wdy, num_households=100_000, T=500, key=None):
538522 ψ_0 = jnp.full(num_households, wdy.y_mean)
539523
540524 ψ_star = update_cross_section(wdy, ψ_0, shift_length=T, key=key)
541- return qe.gini_coefficient(ψ_star), qe.lorenz_curve(ψ_star)
525+ # Convert JAX array to numpy for quantecon functions
526+ ψ_star_np = np.array(ψ_star)
527+ return qe.gini_coefficient(ψ_star_np), qe.lorenz_curve(ψ_star_np)
542528```
543529
544530Now we investigate how the Lorenz curves associated with the wealth distribution change as return to savings varies.
@@ -549,7 +535,7 @@ If you are running this yourself, note that it will take one or two minutes to e
549535
550536This is unavoidable because we are executing a CPU intensive task.
551537
552- In fact the code, which is JIT compiled and parallelized , runs extremely fast relative to the number of computations.
538+ In fact the code, which is JIT compiled by JAX and vectorized , runs extremely fast relative to the number of computations.
553539
554540``` {code-cell} ipython3
555541%%time
@@ -575,7 +561,7 @@ We will look at this again via the Gini coefficient immediately below, but
575561first consider the following image of our system resources when the code above
576562is executing:
577563
578- Since the code is both efficiently JIT compiled and fully parallelized , it's
564+ Since the code is both efficiently JIT compiled by JAX and fully vectorized , it's
579565close to impossible to make this sequence of tasks run faster without changing
580566hardware.
581567
@@ -729,7 +715,9 @@ Now let's see the rank-size plot:
729715``` {code-cell} ipython3
730716fig, ax = plt.subplots()
731717
732- rank_data, size_data = qe.rank_size(ψ_star, c=0.001)
718+ # Convert JAX array to numpy for quantecon functions
719+ ψ_star_np = np.array(ψ_star)
720+ rank_data, size_data = qe.rank_size(ψ_star_np, c=0.001)
733721ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)
734722ax.set_xlabel("log rank")
735723ax.set_ylabel("log size")
0 commit comments