Skip to content

Commit ef9b075

Browse files
committed
Fix JAX concretization errors in scan functions
- Replace jnp.arange(n-1) with length=n-1 in jax.lax.scan calls - Add static_argnums=(2,) to wealth_time_series jit compilation - Use None as xs input since scan functions don't need iteration values Fixes ConcretizationTypeError when JAX tries to compile functions with dynamic array sizes in scan operations.
1 parent 7dfe7ee commit ef9b075

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

lectures/wealth_dynamics.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,14 +378,14 @@ def wealth_time_series(wdy, w_0, n, key):
378378
key1, key2 = jr.split(key)
379379
z_0 = wdy.z_mean + jnp.sqrt(wdy.z_var) * jr.normal(key1)
380380
381-
# Use scan to generate time series
382-
_, w_series = jax.lax.scan(scan_fn, (w_0, z_0, key2), jnp.arange(n-1))
381+
# Use scan to generate time series - use None array for xs since we don't need the input
382+
_, w_series = jax.lax.scan(scan_fn, (w_0, z_0, key2), None, length=n-1)
383383
384384
# Prepend initial value
385385
return jnp.concatenate([jnp.array([w_0]), w_series])
386386
387-
# JIT compile for performance
388-
wealth_time_series = jax.jit(wealth_time_series)
387+
# JIT compile for performance with static argument
388+
wealth_time_series = jax.jit(wealth_time_series, static_argnums=(2,))
389389
```
390390

391391
Now here's a function to simulate a cross section of households forward in time.
@@ -428,7 +428,7 @@ def update_cross_section(wdy, w_distribution, shift_length=500, key=None):
428428
429429
# Simulate forward
430430
(final_w, _, _), _ = jax.lax.scan(
431-
scan_fn, (w_0, z_0, sim_key), jnp.arange(shift_length)
431+
scan_fn, (w_0, z_0, sim_key), None, length=shift_length
432432
)
433433
return final_w
434434

0 commit comments

Comments
 (0)