@@ -253,7 +253,7 @@ def rate_steady_state(model, tol=1e-6):
253253 return x
254254
255255
256- @partial(jax.jit, static_argnums=(2,) )
256+ @partial(jax.jit, static_argnames=['T'] )
257257def simulate_stock_path(model, X0, T):
258258 """
259259 Simulates the sequence of employment and unemployment stocks.
@@ -268,7 +268,7 @@ def simulate_stock_path(model, X0, T):
268268 _, X_path = jax.lax.scan(update_X, X0, jnp.arange(T))
269269 return X_path
270270
271- @partial(jax.jit, static_argnums=(2,) )
271+ @partial(jax.jit, static_argnames=['T'] )
272272def simulate_rate_path(model, x0, T):
273273 """
274274 Simulates the sequence of employment and unemployment rates.
@@ -467,7 +467,7 @@ We can investigate this by simulating the Markov chain.
467467Let's plot the path of the sample averages over 5,000 periods
468468
469469``` {code-cell} ipython3
470- @partial(jax.jit, static_argnums=(1,) )
470+ @partial(jax.jit, static_argnames=['T'] )
471471def simulate_markov_chain(P, T, init_state, key):
472472 """Simulate a Markov chain."""
473473 def step(carry, key):
0 commit comments