[lake_model] Update lecture to JAX and check compliance to style sheet#589
[lake_model] Update lecture to JAX and check compliance to style sheet#589
Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR modernizes the lake_model lecture by transitioning from NumPy/Numba to JAX and ensuring compliance with the project's style guidelines. The update significantly reduces code complexity while improving performance through JAX's JIT compilation.
Key changes:
- Replaces NumPy/Numba with JAX for all numerical computations and JIT compilation
- Refactors the object-oriented
LakeModelclass into a functional programming approach usingNamedTupleand separate JIT-compiled functions - Updates code style to follow project conventions (lowercase titles, consistent formatting)
This comment was marked as outdated.
This comment was marked as outdated.
This comment was marked as outdated.
This comment was marked as outdated.
|
📖 Netlify Preview Ready! Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (40c3703) 📚 Changed Lecture Pages: lake_model |
lectures/lake_model.md
Outdated
| tags: [output_scroll] | ||
| --- | ||
| ```{code-cell} ipython3 | ||
| :tags: [output_scroll] |
There was a problem hiding this comment.
I think we should move to collapse-40 once QuantEcon/quantecon-book-theme#300 is ready.
There was a problem hiding this comment.
@HumphreyYang let's leave this for now and we can do a bulk search and conversion once the theme is improved.
|
Hi @jstac, Many thanks for the detailed review and for setting the rules! I will apply those rules both to myself and during the review process. I’ve pushed some changes based on your review:
I didn’t change the figure because in axes[i].plot(x_path[:, i], lw=2, alpha=0.5)
axes[i].hlines(xbar[i], 0, T, 'r', '--')
I definitely think it’s overkill and lacks clarity, so I completely agree with the principle of keeping lectures simple. To help the discussion, I ran some small experiments: In the notebook below, I compared the list implementation with the The In another notebook, I ran the same comparison on CPU: The I don’t think speed is really an issue here — but I am slightly concerned about not showing readers the best practice 🤔 Looking forward to hearing your thoughts on the best way to present this! |
|
📖 Netlify Preview Ready! Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (0e6e327) 📚 Changed Lecture Pages: lake_model |
|
Many thanks @HumphreyYang for the thoughtful discussion. Iterating with Claude I thought of the following possibility for handling lectures where it's necessary to generate sequences iteratively. First, we have a def generate_path(update_fn, initial_state, num_steps):
"""
Generate a time series by repeatedly applying an update rule.
This is a general-purpose function that can generate any kind of time series
given an update rule and initial conditions. It uses JAX's scan function
to efficiently apply the update rule for the specified number of steps.
Args:
update_fn: Function that takes (current_state, step_index) and returns
(next_state, current_state). The second return value is stored
in the output path.
initial_state: Starting state for the time series
num_steps: Number of time steps to simulate
Returns:
Array containing the full time series path
"""
# Use JAX scan to efficiently apply update_fn num_steps times
# scan returns (final_state, full_path) - we only need the path
_, path = jax.lax.scan(update_fn, initial_state, jnp.arange(num_steps))
return pathThe advantage is that we have this logic once and we get to explain it to the reader before we start using it. Then we have the functions that use this @partial(jax.jit, static_argnames=['T'])
def simulate_stock_path(model, X0, T):
"""
Simulates the sequence of employment and unemployment stocks.
Args:
model: Economic model containing parameters for transition matrices
X0: Initial stock values (employment/unemployment levels)
T: Number of time periods to simulate
Returns:
Array of shape (T, len(X0)) containing stock values at each time step
"""
# Extract transition matrix from the economic model
A, A_hat, g = compute_matrices(model)
def stock_update_rule(current_stocks, time_step):
"""Update rule: apply transition matrix to get next period's stocks."""
next_stocks = A @ current_stocks
return next_stocks, current_stocks
# Ensure initial state is properly formatted
initial_stocks = jnp.atleast_1d(X0)
# Generate the stock path using our general path generator
return generate_path(stock_update_rule, initial_stocks, T)The only thing I don't love about this is that you need to remember the slightly odd structure of the stock_update_rule function -- its nonobvious outputs in particular. I wonder if some of this logic could be pushed into def stock_update_rule(current_stocks, time_step):
"""Update rule: apply transition matrix to get next period's stocks."""
next_stocks = A @ current_stocks
return next_stocksCC @mmcky This discussion is important because it will touch a lot of lectures. The issue is how to generate time paths in an array processing environment where all arrays are immutable. Whatever we decide we should probably add it to the style manual. |
|
this logic should probably be pushed into # Ensure initial state is properly formatted
initial_stocks = jnp.atleast_1d(X0) |
Many thanks, @jstac — this makes a lot of sense to me! I think we can have something like this in the scan update function, wrapping around the single return function to record the carry state: def scan_wrapper(state, x):
next_state = update_fn(state, x)
return next_state, stateThat way, the So the def generate_path(update_fn, initial_state, num_steps):
def scan_wrapper(state, x):
next_state = update_fn(state, x)
return next_state, state
final_state, path = jax.lax.scan(scan_wrapper, initial_state, jnp.arange(num_steps))
return pathThis is pretty clean and removes the awkward "two" returns! I will put this in and update the manual if this looks good to you. I imagine that we only need this in lectures where we have many different sequence generating logic. |
|
Yep, that's perfect @HumphreyYang . Thanks! I suggest update_wrapper instead of scan_wrapper. We should be able to jax.jit generate_path with the first argument static. |
And |
|
Thanks @HumphreyYang . Please add in a docstring to both the outer and inner functions -- that explains what it does so we can put it in the docs and then easily copy-paste it to other lectures. Should |
|
📖 Netlify Preview Ready! Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (4a9c909) 📚 Changed Lecture Pages: lake_model |
|
Many thanks @HumphreyYang . Very nice. I slightly reorganized the code so that we can mix it more with discussion. Regarding the |
|
📖 Netlify Preview Ready! Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (3ed4ac7) 📚 Changed Lecture Pages: lake_model |
Atually the convention is (dim(x), T), so that each column is a state observation x_t. Let's transpose within the function. |
|
Many thanks @jstac! Your are definitely right. It was a bit too late last night 😬 |
|
thanks @HumphreyYang and @jstac -- just catching up with this discussion now the |
|
📖 Netlify Preview Ready! Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (dad7bb1) 📚 Changed Lecture Pages: lake_model |
|
This looks great. Love your work @HumphreyYang |
|
thanks @HumphreyYang and @jstac I will merge this PR this afternoon. |
|
Thanks @mmcky . Just a friendly reminder to open an issue for relevant updates to the manual. |
|
thanks @jstac - indeed. Writing that up now before I merge :-) 👍 |
|
📖 Netlify Preview Ready! Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (64111fc) 📚 Changed Lecture Pages: lake_model |

This PR updates
lake_modellecture to JAX and check the compliance to style sheet.This reduces a lot of code in the lecture!