Skip to content

Comments

[lake_model] Update lecture to JAX and check compliance to style sheet#589

Merged
mmcky merged 16 commits intomainfrom
lake-update
Sep 23, 2025
Merged

[lake_model] Update lecture to JAX and check compliance to style sheet#589
mmcky merged 16 commits intomainfrom
lake-update

Conversation

@HumphreyYang
Copy link
Member

This PR updates lake_model lecture to JAX and check the compliance to style sheet.

This reduces a lot of code in the lecture!

@github-actions
Copy link

github-actions bot commented Sep 2, 2025

@github-actions github-actions bot temporarily deployed to pull request September 2, 2025 17:32 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 2, 2025 17:33 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 4, 2025 08:04 Inactive
@HumphreyYang HumphreyYang marked this pull request as ready for review September 4, 2025 08:10
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR modernizes the lake_model lecture by transitioning from NumPy/Numba to JAX and ensuring compliance with the project's style guidelines. The update significantly reduces code complexity while improving performance through JAX's JIT compilation.

Key changes:

  • Replaces NumPy/Numba with JAX for all numerical computations and JIT compilation
  • Refactors the object-oriented LakeModel class into a functional programming approach using NamedTuple and separate JIT-compiled functions
  • Updates code style to follow project conventions (lowercase titles, consistent formatting)

@github-actions github-actions bot temporarily deployed to pull request September 4, 2025 08:27 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 4, 2025 08:27 Inactive
@netlify

This comment was marked as outdated.

@github-actions

This comment was marked as outdated.

@github-actions github-actions bot temporarily deployed to pull request September 9, 2025 04:16 Inactive
@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (40c3703)

📚 Changed Lecture Pages: lake_model

tags: [output_scroll]
---
```{code-cell} ipython3
:tags: [output_scroll]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should move to collapse-40 once QuantEcon/quantecon-book-theme#300 is ready.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HumphreyYang let's leave this for now and we can do a bulk search and conversion once the theme is improved.

@mmcky mmcky added the review label Sep 17, 2025
@HumphreyYang
Copy link
Member Author

Hi @jstac,

Many thanks for the detailed review and for setting the rules! I will apply those rules both to myself and during the review process.

I’ve pushed some changes based on your review:

  • removed unnecessary create_model functions and redundant functions
  • updated simulate_markov_chain
  • generated the distribution plot using code instead of inserting assets
  • added type hints to parameter-bearing instances such as LakeModel, McCallModel, and EconomyParameters so readers can distinguish them

I didn’t change the figure because in

axes[i].plot(x_path[:, i], lw=2, alpha=0.5)
axes[i].hlines(xbar[i], 0, T, 'r', '--')

ax.hlines restarts the color cycle. So if we don’t explicitly set the color, it ends up using the same blue as the lines from ax.plot. That’s why I had to assign a different color to make it visually distinct — even though I realize that goes against the style guide 😬.

But maybe we should just append to a Python list, or use the at[i].set(b) syntax. Thoughts? cc @HumphreyYang @mmcky @cc7768

I definitely think it’s overkill and lacks clarity, so I completely agree with the principle of keeping lectures simple.

To help the discussion, I ran some small experiments:

In the notebook below, I compared the list implementation with the scan version on GPU:
https://colab.research.google.com/gist/HumphreyYang/e42556b0a81f25367695707da4201bcf/for_scan_speed.ipynb

The for loop version takes 0.61 seconds, while the scan version takes 0.11 seconds.

In another notebook, I ran the same comparison on CPU:
https://colab.research.google.com/gist/HumphreyYang/c385e20d62d021200d49f2cd2b4e33db/for_scan_speed.ipynb

The for loop version takes 0.7 seconds, while the scan version takes 0.09 seconds.

I don’t think speed is really an issue here — but I am slightly concerned about not showing readers the best practice 🤔

Looking forward to hearing your thoughts on the best way to present this!

@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (0e6e327)

📚 Changed Lecture Pages: lake_model

@jstac
Copy link
Contributor

jstac commented Sep 19, 2025

Many thanks @HumphreyYang for the thoughtful discussion.

Iterating with Claude I thought of the following possibility for handling lectures where it's necessary to generate sequences iteratively.

First, we have a generate_path function that handles lax.scan.

def generate_path(update_fn, initial_state, num_steps):
    """
    Generate a time series by repeatedly applying an update rule.
    
    This is a general-purpose function that can generate any kind of time series
    given an update rule and initial conditions. It uses JAX's scan function
    to efficiently apply the update rule for the specified number of steps.
    
    Args:
        update_fn: Function that takes (current_state, step_index) and returns 
                  (next_state, current_state). The second return value is stored
                  in the output path.
        initial_state: Starting state for the time series
        num_steps: Number of time steps to simulate
        
    Returns:
        Array containing the full time series path
    """
    # Use JAX scan to efficiently apply update_fn num_steps times
    # scan returns (final_state, full_path) - we only need the path
    _, path = jax.lax.scan(update_fn, initial_state, jnp.arange(num_steps))
    return path

The advantage is that we have this logic once and we get to explain it to the reader before we start using it.

Then we have the functions that use this generate_path function. E.g.,

@partial(jax.jit, static_argnames=['T'])
def simulate_stock_path(model, X0, T):
    """
    Simulates the sequence of employment and unemployment stocks.
    
    Args:
        model: Economic model containing parameters for transition matrices
        X0: Initial stock values (employment/unemployment levels)  
        T: Number of time periods to simulate
        
    Returns:
        Array of shape (T, len(X0)) containing stock values at each time step
    """
    # Extract transition matrix from the economic model
    A, A_hat, g = compute_matrices(model)
    
    def stock_update_rule(current_stocks, time_step):
        """Update rule: apply transition matrix to get next period's stocks."""
        next_stocks = A @ current_stocks
        return next_stocks, current_stocks
    
    # Ensure initial state is properly formatted
    initial_stocks = jnp.atleast_1d(X0)
    
    # Generate the stock path using our general path generator
    return generate_path(stock_update_rule, initial_stocks, T)

The only thing I don't love about this is that you need to remember the slightly odd structure of the stock_update_rule function -- its nonobvious outputs in particular. I wonder if some of this logic could be pushed into generate_path, so that it could work with a simpler function like

    def stock_update_rule(current_stocks, time_step):
        """Update rule: apply transition matrix to get next period's stocks."""
        next_stocks = A @ current_stocks
        return next_stocks

CC @mmcky This discussion is important because it will touch a lot of lectures. The issue is how to generate time paths in an array processing environment where all arrays are immutable.

Whatever we decide we should probably add it to the style manual.

@jstac
Copy link
Contributor

jstac commented Sep 19, 2025

this logic should probably be pushed into generate_path as well:

    # Ensure initial state is properly formatted
    initial_stocks = jnp.atleast_1d(X0)

@HumphreyYang
Copy link
Member Author

Whatever we decide, we should probably add it to the style manual.

Many thanks, @jstac — this makes a lot of sense to me! I think we can have something like this in the scan update function, wrapping around the single return function to record the carry state:

def scan_wrapper(state, x):
    next_state = update_fn(state, x)
    return next_state, state

That way, the update_fn can take in the single-return function stock_update_rule you suggested.

So the generate_path function will look like this:

def generate_path(update_fn, initial_state, num_steps):

    def scan_wrapper(state, x):
        next_state = update_fn(state, x)
        return next_state, state
    
    final_state, path = jax.lax.scan(scan_wrapper, initial_state, jnp.arange(num_steps))
    
    return path

This is pretty clean and removes the awkward "two" returns!

I will put this in and update the manual if this looks good to you. I imagine that we only need this in lectures where we have many different sequence generating logic.

@jstac
Copy link
Contributor

jstac commented Sep 22, 2025

Yep, that's perfect @HumphreyYang . Thanks!

I suggest update_wrapper instead of scan_wrapper. We should be able to jax.jit generate_path with the first argument static.

@HumphreyYang
Copy link
Member Author

I suggest update_wrapper instead of scan_wrapper. We should be able to jax.jit generate_path with the first argument static.

And num_steps static : ) I will work on this in this PR!

@jstac
Copy link
Contributor

jstac commented Sep 22, 2025

Thanks @HumphreyYang . Please add in a docstring to both the outer and inner functions -- that explains what it does so we can put it in the docs and then easily copy-paste it to other lectures.

Should update_fn(state, x) be update_fn(state, t)? This variable is stepping through the arange isn't it? So it's natural to think of it as a time index.

@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (4a9c909)

📚 Changed Lecture Pages: lake_model

@jstac
Copy link
Contributor

jstac commented Sep 22, 2025

Many thanks @HumphreyYang . Very nice.

I slightly reorganized the code so that we can mix it more with discussion.

Regarding the rate_steady_state function, it's not obvious that we're computing an eigenvector. (In fact I'm not even sure how that works.) Also, tol is not used. Perhaps it would be best to change to an expliciit eigenvector calculation, in order to match the surrounding discussoin.

@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (3ed4ac7)

📚 Changed Lecture Pages: lake_model

@jstac
Copy link
Contributor

jstac commented Sep 22, 2025

    Returns:
        Array of shape (T, dim(x)) containing the time series path
        [x_0, x_1, x_2, ..., x_{T-1}]

Atually the convention is (dim(x), T), so that each column is a state observation x_t. Let's transpose within the function.

@HumphreyYang
Copy link
Member Author

Many thanks @jstac! Your are definitely right. It was a bit too late last night 😬

@mmcky
Copy link
Contributor

mmcky commented Sep 23, 2025

thanks @HumphreyYang and @jstac -- just catching up with this discussion now the live site is fixed. I will use the final solution in this PR to come up with a pattern to add to QuantEcon.manual

@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (dad7bb1)

📚 Changed Lecture Pages: lake_model

@HumphreyYang
Copy link
Member Author

Hi @jstac and @mmcky, I have pushed changes suggested by @jstac. Please let me know how to improve this template further!

@jstac
Copy link
Contributor

jstac commented Sep 23, 2025

This looks great. Love your work @HumphreyYang

@mmcky
Copy link
Contributor

mmcky commented Sep 23, 2025

thanks @HumphreyYang and @jstac

I will merge this PR this afternoon.

@jstac
Copy link
Contributor

jstac commented Sep 23, 2025

Thanks @mmcky . Just a friendly reminder to open an issue for relevant updates to the manual.

@mmcky
Copy link
Contributor

mmcky commented Sep 23, 2025

thanks @jstac - indeed. Writing that up now before I merge :-) 👍

@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-589--sunny-cactus-210e3e.netlify.app (64111fc)

📚 Changed Lecture Pages: lake_model

@mmcky
Copy link
Contributor

mmcky commented Sep 23, 2025

Screenshot 2025-09-23 at 5 18 48 pm

I have looked into this Matplotlib warning. Once we get the new cache built -- this should disappear.

@mmcky mmcky merged commit 19cc20d into main Sep 23, 2025
1 check passed
@mmcky mmcky deleted the lake-update branch September 23, 2025 07:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants