@jessegrabowski
Member

@jessegrabowski jessegrabowski commented Jan 6, 2025

Description

I really like what nutpie gives you while sampling, so I tried to make something using rich that copies it. Example:

test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-07.10-57-56.mp4

Features are:

  1. One progress bar per chain
  2. Sampling statistics per chain. I copied nutpie, but we can haggle over what these should be (or give the user more control)
  3. Color change based on status. Blue when sampling, turns red after a divergence. Finished bar is either green (no divergences) or purple (with divergences).
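The color scheme in item 3 can be written down as a small pure function. This is only a sketch of the rule as described above; the function name `bar_color` and its arguments are made up for illustration and are not taken from the PR's code.

```python
def bar_color(finished: bool, n_divergences: int) -> str:
    """Return the bar color for one chain, following item 3 above.

    `bar_color` is a hypothetical helper, not part of PyMC.
    """
    if finished:
        # Finished bar: green when clean, purple when any divergence occurred.
        return "green" if n_divergences == 0 else "purple"
    # Still sampling: blue until the first divergence, red afterwards.
    return "blue" if n_divergences == 0 else "red"


# (finished, n_divergences) for four example chains
states = [(False, 0), (False, 3), (True, 0), (True, 3)]
print([bar_color(finished, n_div) for finished, n_div in states])
# → ['blue', 'red', 'green', 'purple']
```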

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7634.org.readthedocs.build/en/7634/

@jessegrabowski jessegrabowski changed the title from "Show one progress bars per chain when sampling" to "Show one progress bar per chain when sampling" Jan 6, 2025
Member

@ricardoV94 ricardoV94 left a comment

What does this look like when 1) you have another step sampler in the mix, 2) there's no NUTS at all, or 3) there is more than one NUTS step sampler?

pymc/util.py Outdated
Member

According to Oriol this should just be 20xx-present

@ricardoV94
Member

Doesn't need to be this PR, but it would be nice to show a relevant statistic for each sampler (or at least when a single non-NUTS sampler is being used).

Conversely, these columns should not be shown when there's no NUTS, as they give a false sense that everything is going great.

@jessegrabowski
Member Author

test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-07.22-09-52.mp4

Here's a comparison between NUTS and a non-NUTS sampler.

Ideally we'd add a method to the step samplers themselves that would return the rich columns that sampler wants to use, then we just gather them and display. In that case you could even show different sampler stats from different steps in the same run. Maybe it's worth doing. The actual code for this PR is pretty gnarly.

@jessegrabowski
Member Author

I moved the responsibility for setting up the progressbars and updating stats to the step samplers. This means each step method can choose what stats are to be shown on the progress bars, and we can also combine them. Example vid attached.

test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-08.18-16-37.mp4

This is a pretty big scope creep for this PR, so I'm not against reverting these changes and going with something more basic. If we like it though I can lean into it.

I will say it's broken right now, because when you have e.g. multiple Metropolis steps (one per variable) the only stats that get reported are those of the last one. It needs some logic for aggregating the stats across samplers that report the same stats.

@ricardoV94
Member

The step sampler specifics look amazing 😍 Gonna give it a try today.

I'll test it but I assume things behave gracefully if the step samplers don't specify the display columns info?

@jessegrabowski
Member Author

jessegrabowski commented Jan 8, 2025

No, it will break. I need to put in a default for the base class. It just needs to return empty stuff.
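The idea of a base-class default that "returns empty stuff" can be sketched as follows. All class and method names here (`StepBase`, `progress_columns`, `progress_stats`, `gather_columns`) and the column labels are hypothetical stand-ins chosen for illustration; they are not the names used in PyMC.

```python
class StepBase:
    """Stand-in for a step-method base class (hypothetical, not PyMC's)."""

    def progress_columns(self):
        # Default: contribute no extra columns, so a step method that
        # does not override this leaves the progress table unchanged.
        return []

    def progress_stats(self, step_stats):
        # Default: contribute no values to display.
        return {}


class NutsLike(StepBase):
    """A step method that opts in to showing its own stats."""

    def progress_columns(self):
        return ["Divergences", "Step size"]

    def progress_stats(self, step_stats):
        return {
            "Divergences": step_stats["diverging"],
            "Step size": step_stats["step_size"],
        }


class GibbsLike(StepBase):
    """A step method that defines no display information."""


def gather_columns(steps):
    """Collect the columns requested by every step, without duplicates."""
    columns = []
    for step in steps:
        for name in step.progress_columns():
            if name not in columns:
                columns.append(name)
    return columns


print(gather_columns([GibbsLike(), NutsLike()]))
# → ['Divergences', 'Step size']
```

With a default like this, a sampler that implements nothing simply adds no columns instead of raising an error.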

@ricardoV94
Member

I would still like to see the global runtime and ETA like we had before. Is that feasible or too ugly?

Re: repeated samplers, show the mean? Or maybe only display specialized info when a single step sampler is being used?
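A minimal sketch of the "show the mean" option for repeated samplers, assuming each step reports its stats as a plain dict. The helper `aggregate_stats` and the stat name `accept_rate` are illustrative only and are not taken from the PR.

```python
from collections import defaultdict
from statistics import mean


def aggregate_stats(per_step_stats):
    """Average stats that share a name across several step methods.

    `per_step_stats` is a list of dicts, one per step method. Stats
    reported by only one step pass through unchanged, since they are
    the mean of a single value.
    """
    grouped = defaultdict(list)
    for stats in per_step_stats:
        for name, value in stats.items():
            grouped[name].append(value)
    return {name: mean(values) for name, values in grouped.items()}


# Two Metropolis-like steps (one per variable) plus one NUTS-like step.
per_step = [
    {"accept_rate": 0.25},
    {"accept_rate": 0.75},
    {"divergences": 0},
]
print(aggregate_stats(per_step))
# → {'accept_rate': 0.5, 'divergences': 0}
```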

@jessegrabowski
Member Author

jessegrabowski commented Jan 8, 2025

I added the base implementation, so things degrade gracefully if a step sampler has no implementation of its own. The following would only show the NUTS stats, for example, because there's no implementation for BinaryGibbsMetropolis:

import pymc as pm

with pm.Model() as m:
    x = pm.Bernoulli('x', p=0.5)
    y = pm.Normal('y', mu=pm.math.switch(x, -3, 3), sigma=10, shape=(10,))
    
    idata = pm.sample(step=[pm.BinaryGibbsMetropolis(x), pm.NUTS(y)], tune=2000, draws=2000, chains=8, cores=8, compile_kwargs={'mode':'NUMBA'})

Re: global, yes we can keep it. But we can't have it as a single long bar that breaks the columns, because there's no colspan operator for rich tables (see Textualize/rich#164).

We could make a separate table though. It just won't be as pretty as nutpie.

I was thinking about the mean as well. If it only shows up when there's a single step sampler it would be pretty rare that anyone would use it, because the non-NUTS samplers pretty much always show up as one per variable.

We might also need some priority logic to decide what to show if too many stats get involved. You can see that just NUTS + Metropolis already breaks the table. We could do a LOW/MEDIUM/HIGH priority for displaying stats, and display at most 5 at any time?
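The LOW/MEDIUM/HIGH proposal could look roughly like this. It is only a sketch of the idea floated here: `Priority`, `select_columns`, and the column names are invented for illustration, and nothing in this thread says the merged code works this way.

```python
from enum import IntEnum


class Priority(IntEnum):
    LOW = 0
    MEDIUM = 1
    HIGH = 2


def select_columns(candidates, max_columns=5):
    """Keep at most `max_columns` stats, preferring higher priority.

    `candidates` is a list of (name, Priority) pairs in display order.
    Ties are broken by original position, and the surviving columns
    are returned in their original display order.
    """
    ranked = sorted(
        enumerate(candidates),
        key=lambda item: (-item[1][1], item[0]),
    )
    keep = {index for index, _ in ranked[:max_columns]}
    return [name for index, (name, _) in enumerate(candidates) if index in keep]


candidates = [
    ("Divergences", Priority.HIGH),
    ("Step size", Priority.MEDIUM),
    ("Grad evals", Priority.LOW),
    ("Tune", Priority.LOW),
    ("Accept rate", Priority.HIGH),
    ("Scaling", Priority.MEDIUM),
    ("Elapsed", Priority.HIGH),
]
print(select_columns(candidates))
# → ['Divergences', 'Step size', 'Accept rate', 'Scaling', 'Elapsed']
```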

@ricardoV94
Member

Sequential sampling (cores=1) still has the old approach. It has one bar per chain but not the stats

@ricardoV94
Member

> Re: global, yes we can keep it. But we can't have it as a single long bar that breaks the columns, because there's no colspan operator for rich tables (see Textualize/rich#164).

What if we show it as a column per chain then? Elapsed/left?
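Per-chain elapsed/remaining values reduce to simple arithmetic on the draw count. The sketch below assumes a constant sampling rate; the function names are hypothetical. The zero-elapsed guard mirrors the "Add guard against divide by zero when computing draws per second" commit listed at the end of this thread, but the code itself is not taken from the PR.

```python
def draws_per_second(draws_done, elapsed_seconds):
    """Sampling rate for one chain, guarded against division by zero."""
    if elapsed_seconds <= 0:
        return 0.0
    return draws_done / elapsed_seconds


def remaining_seconds(draws_done, total_draws, elapsed_seconds):
    """Estimate the time left, or None when no rate is available yet."""
    rate = draws_per_second(draws_done, elapsed_seconds)
    if rate == 0:
        return None
    return (total_draws - draws_done) / rate


print(remaining_seconds(500, 2000, 10.0))  # → 30.0
print(remaining_seconds(0, 2000, 0.0))     # → None
```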

@jessegrabowski
Member Author

New version with timing info:

test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-12.19-44-52.mp4

@ricardoV94
Member

> Sequential sampling (cores=1) still has the old approach. It has one bar per chain but not the stats

Did you address this?

@jessegrabowski
Member Author

Not yet, but it will be an easy fix.

@ricardoV94
Member

Some failing tests as well. Otherwise I'm happy with the changes. I'll post it in the Discord to see if anybody has big complaints.

@aloctavodia
Member

This looks really nice and modern, and it's very informative, but do we have an option for a single progress bar with less information?

@jessegrabowski
Member Author

We can do that painlessly yeah

@twiecki
Member

twiecki commented Jan 15, 2025

This looks great. I assume blue means tuning and red means post-tuning? If so, I wonder if red is the best color choice, as it suggests something has gone wrong. Maybe replace red with green? Or make tuning red and sampling blue?

@ricardoV94
Member

> This looks great. I assume blue means tuning and red means post-tuning? If so, I wonder if red is the best color choice, as it suggests something has gone wrong. Maybe replace red with green? Or make tuning red and sampling blue?

It turns red if there's any divergence

@twiecki
Member

twiecki commented Jan 15, 2025

I see, maybe then green post-tuning without divergences? Or maybe a colorblind-friendly color.

@aloctavodia
Member

If you want some colorblind-friendly palettes: https://github.com/arviz-devs/arviz-plots/tree/main/src/arviz_plots/styles

@fonnesbeck
Member

fonnesbeck commented Jan 15, 2025

I like it a lot.

I don't know if we need different colors for pre-/post-tuning.

Can we get red for any warning, not just divergences (so that it makes users read the warning)?

Definitely no green if we are using red. I like blue/red for clean/warning.

@ricardoV94
Member

I think 2 colors is enough. It will be clear when you use it, compared to a GIF. Otherwise a single color.

@fonnesbeck what sort of warnings are you thinking about? Are they emitted during sampling or only at the end?

@tomicapretto
Contributor

Have you noticed that Draws stops at something smaller than the total number of draws?

image

You can also see it in your last video

image

@ricardoV94
Member

Usually needs one more call at the end, weren't we doing that before already?

@jessegrabowski
Member Author

jessegrabowski commented Jan 25, 2025

It only happens in the combined mode, I have to fix the logic there.
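The "one more call at the end" remark can be illustrated with a throttled update loop. This is a generic sketch of one common way a counter ends short; it is an assumption for illustration and is not confirmed as the cause of the bug in this PR. The function `report_progress` and its parameters are hypothetical.

```python
def report_progress(total_draws, update_every):
    """Return the draw counts a throttled progress display would show.

    The display is refreshed only every `update_every` draws, so a
    final refresh is needed whenever `total_draws` is not a multiple
    of `update_every`.
    """
    reported = []
    for draw in range(1, total_draws + 1):
        if draw % update_every == 0:
            reported.append(draw)
    # Final call: without it the display stops at the last multiple
    # of `update_every` instead of the true total.
    if not reported or reported[-1] != total_draws:
        reported.append(total_draws)
    return reported


print(report_progress(1000, 300))
# → [300, 600, 900, 1000]
```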

@juanitorduz
Contributor

This looks great 🚀 !

@jessegrabowski
Member Author

I want to rename the options for the progress bar to be less random, I was thinking about:

  • one_bar for the combined mode
  • per_chain for the one bar per chain mode
  • full_stats to show all the stats possible
  • fewer_stats to show timing only

Any thoughts?

@ricardoV94
Member

ricardoV94 commented Jan 26, 2025

combined/split, with or without stats:

  • combined
  • split
  • combined+stats
  • split+stats
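One way to handle these four strings in a single argument is to split on `+` and validate the parts. The sketch below is hypothetical: `parse_progressbar` is not a PyMC function, accepting the parts in either order follows the "Allow all permutations of arguments to progressbar" commit listed at the end of this thread, and mapping `True` to `"split+stats"` is an assumption made for this example.

```python
def parse_progressbar(option):
    """Return (enabled, combined, show_stats) for a progressbar option.

    Accepts a bool or a string such as "combined", "split",
    "combined+stats" or "stats+split" (either order).
    """
    if option is False:
        return (False, False, False)
    if option is True:
        option = "split+stats"  # assumed default, for illustration only

    parts = set(option.split("+"))
    unknown = parts - {"combined", "split", "stats"}
    layouts = parts & {"combined", "split"}
    # Exactly one layout is required; "stats" is an optional modifier.
    if unknown or len(layouts) != 1:
        raise ValueError(f"Invalid progressbar option: {option!r}")

    return (True, "combined" in parts, "stats" in parts)


print(parse_progressbar("combined"))     # → (True, True, False)
print(parse_progressbar("stats+split"))  # → (True, False, True)
```

Encoding both choices in one string avoids adding a second keyword argument for the amount of detail.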

@fonnesbeck
Member

> I want to rename the options for the progress bar to be less random, I was thinking about:
>
>   • one_bar for the combined mode
>   • per_chain for the one bar per chain mode
>   • full_stats to show all the stats possible
>   • fewer_stats to show timing only
>
> Any thoughts?

Maybe instead of full_stats vs fewer_stats, just have a verbosity argument with 0, 1, 2 as options?

@jessegrabowski
Member Author

> Maybe instead of full_stats vs fewer_stats, just have a verbosity argument with 0, 1, 2 as options?

I wanted to avoid adding more keyword arguments to pm.sample, because I agree with @cluhmann that, in general, it's way too bloated.

@jessegrabowski jessegrabowski merged commit 0db176c into pymc-devs:main Jan 27, 2025
25 checks passed
@jessegrabowski jessegrabowski deleted the more-progress branch January 27, 2025 07:07
@tomicapretto
Contributor

Congrats @jessegrabowski!

vandalt pushed a commit to vandalt/pymc that referenced this pull request May 14, 2025
* One progress bar per chain when samplings

* Add guard against divide by zero when computing draws per second

* No more purple

* Step samplers are responsible for setting up progress bars

* Fix typos

* Add progressbar defaults to BlockedStep ABC

* pre-commit

* Only update NUTS divergence stats after tuning

* Add `Elapsed` and `Remaining` columns

* Remove green color when chain finishes

* Create `ProgressManager` class to handle progress bars

* Yield `stats` from `_iter_sample`

* Use `ProgressManager` in `_sample_many`

* pre-commit

* Explicit case handling for `progressbar` argument

* Allow all permutations of arguments to progressbar

* Appease mypy

* Add True case

* Fix final count when `progress = "combined"`

* Update docstrings

* mypy + cleanup

* Syntax error in typehint

* Simplify progressbar choices, update docstring

* Incorporate feedback

* Be verbose with progressbar settings

7 participants