Show one progress bar per chain when sampling #7634
What does this look like when 1) you have another step sampler in the mix, 2) there's no NUTS at all, or 3) there is more than one NUTS step sampler?
pymc/util.py
According to Oriol, this should just be 20xx-present.
| It doesn't need to be in this PR, but it would be nice to show a relevant statistic for each sampler (or at least when a single non-NUTS sampler is being used). Conversely, these columns shouldn't be shown when there's no NUTS, as they give a false sense that everything is going great. |
| Here's a comparison between NUTS and a non-NUTS sampler (video attachment: test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-07.22-09-52.mp4). Ideally we'd add a method to the step samplers themselves that returns the rich columns each sampler wants to use; then we just gather them and display them. In that case you could even show different sampler stats from different steps in the same run. Maybe it's worth doing. The actual code for this PR is pretty gnarly. |
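A minimal sketch of the "each sampler declares its own columns, then we gather them" idea, using plain strings instead of rich columns so it runs with the standard library alone. The `progressbar_columns()` method, the `NUTSLike`/`MetropolisLike` classes, and the `gather_columns`/`render_row` helpers are all hypothetical illustrations, not PyMC's API.

```python
from typing import Callable

# Hypothetical: each step declares (column_name, formatter) pairs.
Column = tuple[str, Callable[[dict], str]]


class NUTSLike:
    """Hypothetical NUTS-style step reporting divergences and step size."""

    def progressbar_columns(self) -> list[Column]:
        return [
            ("Divergences", lambda s: str(s["divergences"])),
            ("Step size", lambda s: f"{s['step_size']:.2f}"),
        ]


class MetropolisLike:
    """Hypothetical Metropolis-style step reporting its tuned scaling."""

    def progressbar_columns(self) -> list[Column]:
        return [("Scaling", lambda s: f"{s['scaling']:.3f}")]


def gather_columns(steps) -> list[Column]:
    """Concatenate the columns every step asks for, keeping only the first
    occurrence of each name so repeated steps don't add duplicate columns."""
    seen: set[str] = set()
    columns: list[Column] = []
    for step in steps:
        for name, fmt in step.progressbar_columns():
            if name not in seen:
                seen.add(name)
                columns.append((name, fmt))
    return columns


def render_row(columns: list[Column], stats: dict) -> list[str]:
    """Format one chain's stats into table cells, one per gathered column."""
    return [fmt(stats) for _, fmt in columns]
```

Because the display code only ever calls `progressbar_columns()`, mixing samplers in one run just means concatenating their column lists.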
| I moved the responsibility for setting up the progress bars and updating stats to the step samplers. This means each step method can choose which stats are shown on the progress bars, and we can also combine them. Example video attached: test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-08.18-16-37.mp4. This is a pretty big scope creep for this PR, so I'm not against reverting these changes and going with something more basic. If we like it, though, I can lean into it. I will say it's broken right now: when you have, e.g., multiple Metropolis steps (one per variable), only the stats from the last one get reported. It needs some logic for aggregating stats across samplers that report the same stats. |
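One way the missing aggregation logic could work, sketched under the assumption that each step reports a flat `dict` of numeric stats: group values that share a key and reduce them with a per-stat rule (mean for continuous quantities, sum for counts). The `AGGREGATORS` table, the stat names, and `aggregate_stats` are hypothetical, not part of PyMC.

```python
from statistics import mean
from typing import Callable

# Hypothetical per-stat reduction rules; stats not listed fall back to the mean.
AGGREGATORS: dict[str, Callable[[list[float]], float]] = {
    "accept_rate": mean,
    "scaling": mean,
    "divergences": sum,
}


def aggregate_stats(per_step_stats: list[dict[str, float]]) -> dict[str, float]:
    """Combine stats from several step samplers that report the same keys.

    A naive ``merged.update(stats)`` loop would keep only the values from the
    last sampler, which is the failure mode with one Metropolis step per variable.
    """
    grouped: dict[str, list[float]] = {}
    for stats in per_step_stats:
        for key, value in stats.items():
            grouped.setdefault(key, []).append(value)
    return {key: AGGREGATORS.get(key, mean)(values) for key, values in grouped.items()}
```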
| The step sampler specifics look amazing 😍 Gonna give it a try today. I'll test it, but I assume things behave gracefully if the step samplers don't specify the display column info? |
| No, it will break. I need to put a default in the base class; it just needs to return empty values. |
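The fallback described here can be sketched as an abstract base class whose progress-bar hooks have concrete defaults returning empty values, so samplers that never override them contribute no columns instead of breaking. `BaseStep`, `progressbar_config`, `update_stats`, `GibbsLike`, and `NUTSLike` are hypothetical names; the PR adds its defaults to PyMC's `BlockedStep`, and its real hook names and signatures may differ.

```python
from abc import ABC, abstractmethod


class BaseStep(ABC):
    """Hypothetical step-sampler base class with progress-bar hooks."""

    @abstractmethod
    def step(self, point: dict) -> dict:
        """Advance the sampler by one draw (each subclass must implement this)."""

    def progressbar_config(self) -> tuple[list[str], dict[str, str]]:
        # Default: no extra columns and no initial values, so a sampler that
        # doesn't opt in is shown with only the shared columns.
        return [], {}

    def update_stats(self, stats: dict) -> dict[str, str]:
        # Default: nothing sampler-specific to display.
        return {}


class GibbsLike(BaseStep):
    """Never overrides the progress-bar hooks, so it inherits the empty defaults."""

    def step(self, point: dict) -> dict:
        return point


class NUTSLike(BaseStep):
    """Opts in to showing a divergence count."""

    def step(self, point: dict) -> dict:
        return point

    def progressbar_config(self) -> tuple[list[str], dict[str, str]]:
        return ["Divergences"], {"Divergences": "0"}

    def update_stats(self, stats: dict) -> dict[str, str]:
        return {"Divergences": str(stats.get("diverging", 0))}
```

Making only `step` abstract keeps the progress-bar hooks optional, which is what lets existing samplers keep working unchanged.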
| I would still like to see the global runtime and ETA like we had before. Is that feasible, or too ugly? Re: repeated samplers, show the mean? Or maybe only display specialized info when a single step sampler is being used? |
| I added the base implementation, so things degrade gracefully if a step sampler has no implementation. For example, this would only show the NUTS stats, because there's no implementation for BinaryGibbsMetropolis:

import pymc as pm

with pm.Model() as m:
    x = pm.Bernoulli('x', p=0.5)
    y = pm.Normal('y', mu=pm.math.switch(x, -3, 3), sigma=10, shape=(10,))

    # BinaryGibbsMetropolis has no progress bar implementation, so only the NUTS columns appear
    idata = pm.sample(step=[pm.BinaryGibbsMetropolis(x), pm.NUTS(y)], tune=2000, draws=2000, chains=8, cores=8, compile_kwargs={'mode':'NUMBA'})

Re: the global bar, yes, we can keep it. But we can't have it as a single long bar spanning all the columns, because rich tables have no colspan option (see Textualize/rich#164). We could make a separate table, though; it just won't be as pretty as nutpie's. I was thinking about the mean as well. If specialized info only showed up when there's a single step sampler, hardly anyone would ever see it, because non-NUTS samplers are almost always assigned one per variable. We might also need some priority logic to decide what to show when too many stats are involved; you can see that NUTS + Metropolis alone already breaks the table. We could assign LOW/MEDIUM/HIGH display priorities to stats and only ever show at most 5? |
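The LOW/MEDIUM/HIGH proposal could look roughly like this sketch, assuming each requested stat carries a priority and the table is capped at five stat columns. The `Priority` enum, `select_stats`, and the stat names in the usage note are hypothetical.

```python
from enum import IntEnum


class Priority(IntEnum):
    LOW = 0
    MEDIUM = 1
    HIGH = 2


def select_stats(requested: list[tuple[str, Priority]], max_columns: int = 5) -> list[str]:
    """Pick at most ``max_columns`` stat names, highest priority first.

    If several samplers request the same stat, its highest priority wins.
    Ties keep first-request order because ``sorted`` is stable, even with
    ``reverse=True``.
    """
    best: dict[str, Priority] = {}
    for name, priority in requested:
        if name not in best or priority > best[name]:
            best[name] = priority
    ranked = sorted(best, key=lambda name: best[name], reverse=True)
    return ranked[:max_columns]
```

With NUTS requesting divergences and step size as HIGH and tree depth as LOW, plus a Metropolis step requesting scaling as MEDIUM, the LOW stats are the first to be dropped once the cap is reached.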
| Sequential sampling (cores=1) still uses the old approach: it has one bar per chain, but no stats. |
| What if we show them as a column per chain, then? |
| New version with timing info: test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-12.19-44-52.mp4 |
| Did you address this? |
| Not yet, but it will be an easy fix. | 
| There are some failing tests as well. Otherwise, I'm happy with the changes. I'll post it in the Discord to see if anybody has big complaints. |
| This looks really nice and modern, and it's very informative, but do we have an option for a single progress bar with less information? |
| Yeah, we can do that painlessly. |
| This looks great. I assume blue means tuning and red means post-tuning? If so, I wonder if red is the best color choice, as it suggests something has gone wrong. Maybe replace red with green? Or make tuning red and sampling blue? |
| It turns red if there's any divergence. |
| I see. Then maybe green for post-tuning without divergences? Or maybe a colorblind-friendly color. |
| If you want some colorblind-friendly palettes: https://github.com/arviz-devs/arviz-plots/tree/main/src/arviz_plots/styles |
| I like it a lot. I don't know if we need different colors for pre-/post-tuning. Can we get red for any warning, not just divergences (so that it makes users read the warning)? Definitely no green if we are using red. I like blue/red for clean/warning. | 
| I think two colors are enough; it will be clearer in actual use than in a GIF. Otherwise, a single color. @fonnesbeck, what sort of warnings are you thinking about? Are they emitted during sampling or only at the end? |
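A sketch of the color rule this discussion converges on: blue while a chain is clean, red once it has any divergence or other warning. It also reflects the later commit "Only update NUTS divergence stats after tuning" by ignoring divergences during tuning. `ChainStatus`, `record_draw`, and `bar_color` are hypothetical, and the color strings are just the names used in the thread.

```python
from dataclasses import dataclass, field


@dataclass
class ChainStatus:
    """Hypothetical per-chain state tracked by the progress display."""

    tuning: bool = True
    divergences: int = 0
    warnings: list[str] = field(default_factory=list)

    def record_draw(self, diverged: bool) -> None:
        # Divergences during tuning are expected, so only count them afterwards.
        if diverged and not self.tuning:
            self.divergences += 1


def bar_color(status: ChainStatus, clean: str = "blue", failing: str = "red") -> str:
    """Return ``failing`` if the chain has any divergence or warning, else ``clean``."""
    if status.divergences > 0 or status.warnings:
        return failing
    return clean
```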
| I think we need to update the docstring here | 
| This usually needs one more call at the end; weren't we already doing that before? |
| It only happens in combined mode; I have to fix the logic there. |
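One plausible reason a final call is needed: if the display only refreshes every `n` draws, the last partial batch is never shown unless the loop forces one more refresh when it finishes. A minimal sketch with a hypothetical `ThrottledCounter` standing in for the progress bar (not the PR's actual code):

```python
class ThrottledCounter:
    """Hypothetical progress counter that only 'redraws' every `every` draws."""

    def __init__(self, total: int, every: int = 10) -> None:
        self.total = total
        self.every = every
        self.completed = 0  # true number of draws done
        self.displayed = 0  # what the user currently sees

    def advance(self) -> None:
        self.completed += 1
        if self.completed % self.every == 0:
            self.refresh()

    def refresh(self) -> None:
        self.displayed = self.completed


def run(total: int, every: int, final_refresh: bool) -> int:
    """Simulate a sampling loop and return the count shown at the end."""
    bar = ThrottledCounter(total, every)
    for _ in range(total):
        bar.advance()
    if final_refresh:
        # Without this, a total that isn't a multiple of `every` leaves the
        # bar stuck short of 100%.
        bar.refresh()
    return bar.displayed
```

For example, 25 draws refreshed every 10 would display 20 without the final refresh and 25 with it.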
| This looks great 🚀 ! | 
| I want to rename the progress bar options to be less arbitrary. I was thinking about: 
 Any thoughts? | 
| combined/split, with or without stats: |
| Maybe instead of |
| I wanted to avoid having additional keyword arguments to |
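A sketch of how a single `progressbar` argument could encode both choices (combined vs. split, with or without stats) without extra keyword arguments, accepting the `+stats` part in either order and plain `True`/`False`. The parser name, the returned tuple, and mapping `True` to split-with-stats are assumptions for illustration; the merged `pm.sample` docstring is the authority on the accepted values.

```python
from __future__ import annotations


def parse_progressbar(option: bool | str) -> tuple[bool, bool, bool]:
    """Return (show, combined, show_stats) for a hypothetical progressbar option.

    Accepts True/False, "combined", "split", and either of those joined with
    "stats" by "+" in any order, e.g. "split+stats" or "stats+split".
    """
    if option is True:
        return True, False, True  # assumed default: one bar per chain, with stats
    if option is False:
        return False, False, False

    parts = set(option.lower().split("+"))
    show_stats = "stats" in parts
    parts.discard("stats")
    if parts == {"combined"}:
        return True, True, show_stats
    if parts == {"split"}:
        return True, False, show_stats
    raise ValueError(f"Invalid progressbar option: {option!r}")
```

Parsing into a set makes "split+stats" and "stats+split" equivalent, which matches the commit "Allow all permutations of arguments to progressbar".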
| Congrats @jessegrabowski! | 
* One progress bar per chain when sampling
* Add guard against divide by zero when computing draws per second
* No more purple
* Step samplers are responsible for setting up progress bars
* Fix typos
* Add progressbar defaults to BlockedStep ABC
* pre-commit
* Only update NUTS divergence stats after tuning
* Add `Elapsed` and `Remaining` columns
* Remove green color when chain finishes
* Create `ProgressManager` class to handle progress bars
* Yield `stats` from `_iter_sample`
* Use `ProgressManager` in `_sample_many`
* pre-commit
* Explicit case handling for `progressbar` argument
* Allow all permutations of arguments to progressbar
* Appease mypy
* Add True case
* Fix final count when `progress = "combined"`
* Update docstrings
* mypy + cleanup
* Syntax error in typehint
* Simplify progressbar choices, update docstring
* Incorporate feedback
* Be verbose with progressbar settings
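The divide-by-zero guard on draws per second and the `Elapsed`/`Remaining` columns from the commits above amount to roughly the arithmetic below. This is a minimal sketch with a hypothetical `progress_timing` helper, not the PR's code: when nothing has been timed yet it returns 0 draws/s and an unknown (`None`) remaining time instead of raising `ZeroDivisionError`.

```python
from __future__ import annotations

from datetime import timedelta


def progress_timing(
    completed: int, total: int, elapsed_seconds: float
) -> tuple[float, timedelta, timedelta | None]:
    """Return (draws_per_second, elapsed, remaining) for one chain."""
    elapsed = timedelta(seconds=elapsed_seconds)
    if elapsed_seconds <= 0 or completed == 0:
        # Guard: right after start-up there is no measured speed yet, so
        # report 0 draws/s and leave the remaining time unknown.
        return 0.0, elapsed, None
    speed = completed / elapsed_seconds
    remaining = timedelta(seconds=(total - completed) / speed)
    return speed, elapsed, remaining
```

For example, 500 of 2000 draws after 10 seconds gives 50 draws/s and 30 seconds remaining.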


Description
I really like what nutpie gives you while sampling, so I tried to make something using rich that copies it. Example:
test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-07.10-57-56.mp4
Features are:
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7634.org.readthedocs.build/en/7634/