Commit ef9b075
committed
Fix JAX concretization errors in scan functions
- Replace jnp.arange(n-1) with length=n-1 in jax.lax.scan calls
- Add static_argnums=(2,) to wealth_time_series jit compilation
- Use None as xs input since scan functions don't need iteration values
Fixes ConcretizationTypeError when JAX tries to compile functions
with dynamic array sizes in scan operations.1 parent 7dfe7ee commit ef9b075
1 file changed
+5
-5
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
378 | 378 | | |
379 | 379 | | |
380 | 380 | | |
381 | | - | |
382 | | - | |
| 381 | + | |
| 382 | + | |
383 | 383 | | |
384 | 384 | | |
385 | 385 | | |
386 | 386 | | |
387 | | - | |
388 | | - | |
| 387 | + | |
| 388 | + | |
389 | 389 | | |
390 | 390 | | |
391 | 391 | | |
| |||
428 | 428 | | |
429 | 429 | | |
430 | 430 | | |
431 | | - | |
| 431 | + | |
432 | 432 | | |
433 | 433 | | |
434 | 434 | | |
| |||
0 commit comments