Skip to content

Commit 9354978

Browse files
committed
add marimo compat progress bar for pymc sampler
1 parent 011fb35 commit 9354978

File tree

1 file changed

+139
-47
lines changed

1 file changed

+139
-47
lines changed

pymc/progress_bar.py

Lines changed: 139 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from collections.abc import Iterable
15-
from typing import TYPE_CHECKING, Literal
15+
from typing import TYPE_CHECKING, Literal, Protocol, cast
1616

1717
from rich.box import SIMPLE_HEAD
1818
from rich.console import Console
@@ -192,6 +192,132 @@ def callbacks(self, task: "Task"):
192192
self.finished_style = self.default_finished_style
193193

194194

195+
class ProgressBar(Protocol):
196+
@property
197+
def tasks(self):
198+
"""Get the tasks in the progress bar."""
199+
200+
def add_task(self, *args, **kwargs):
201+
"""Add a task to the progress bar."""
202+
203+
def update(self, task_id, **kwargs):
204+
"""Update the task with the given ID with the provided keyword arguments."""
205+
206+
def __enter__(self):
207+
"""Enter the context manager."""
208+
209+
def __exit__(self, exc_type, exc_val, exc_tb):
210+
"""Exit the context manager."""
211+
212+
213+
def compute_draw_speed(elapsed, draws):
214+
speed = draws / max(elapsed, 1e-6)
215+
216+
if speed > 1 or speed == 0:
217+
unit = "draws/s"
218+
else:
219+
unit = "s/draws"
220+
speed = 1 / speed
221+
222+
return speed, unit
223+
224+
225+
def create_rich_progress_bar(full_stats, step_columns, progressbar, progressbar_theme):
226+
columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]
227+
228+
if full_stats:
229+
columns += step_columns
230+
231+
columns += [
232+
TextColumn(
233+
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
234+
table_column=Column("Sampling Speed", ratio=1),
235+
),
236+
TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)),
237+
TimeRemainingColumn(table_column=Column("Remaining", ratio=1)),
238+
]
239+
240+
return CustomProgress(
241+
RecolorOnFailureBarColumn(
242+
table_column=Column("Progress", ratio=2),
243+
failing_color="tab:red",
244+
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
245+
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue
246+
),
247+
*columns,
248+
console=Console(theme=progressbar_theme),
249+
disable=not progressbar,
250+
include_headers=True,
251+
)
252+
253+
254+
class MarimoProgressTask:
255+
def __init__(self, *args, **kwargs):
256+
self.args = args
257+
self.kwargs = kwargs
258+
259+
@property
260+
def chain_idx(self) -> int:
261+
return self.kwargs.get("chain_idx", 0)
262+
263+
@property
264+
def total(self):
265+
return self.kwargs.get("total", 0)
266+
267+
@property
268+
def elapsed(self):
269+
return self.kwargs.get("elapsed", 0)
270+
271+
272+
class MarimoProgressBar:
273+
def __init__(self) -> None:
274+
self.tasks = []
275+
self.divergences = {}
276+
277+
def __enter__(self):
278+
"""Enter the context manager."""
279+
import marimo as mo
280+
281+
total_draws = (self.tasks[0].total + 1) * len(self.tasks)
282+
283+
self.bar = mo.status.progress_bar(total=total_draws, title="Sampling PyMC model")
284+
285+
def __exit__(self, exc_type, exc_val, exc_tb):
286+
"""Exit the context manager."""
287+
self.bar._finish()
288+
289+
def add_task(self, *args, **kwargs):
290+
"""Add a task to the progress bar."""
291+
task = MarimoProgressTask(*args, **kwargs)
292+
self.tasks.append(task)
293+
return task
294+
295+
def update(self, task_id, **kwargs):
296+
"""Update the task with the given ID with the provided keyword arguments."""
297+
if self.bar.progress.current >= cast(int, self.bar.progress.total):
298+
return
299+
300+
self.divergences[task_id.chain_idx] = kwargs.get("divergences", 0)
301+
302+
total_divergences = sum(self.divergences.values())
303+
304+
update_kwargs = {}
305+
if total_divergences > 0:
306+
word = "draws" if total_divergences > 1 else "draw"
307+
update_kwargs["subtitle"] = f"{total_divergences} diverging {word}"
308+
309+
self.bar.progress.update(**update_kwargs)
310+
311+
312+
def in_marimo_notebook() -> bool:
313+
try:
314+
import marimo as mo
315+
316+
return mo.running_in_notebook()
317+
except ImportError:
318+
return False
319+
320+
195321
class ProgressBarManager:
196322
"""Manage progress bars displayed during sampling."""
197323

@@ -203,6 +329,7 @@ def __init__(
203329
tune: int,
204330
progressbar: bool | ProgressBarType = True,
205331
progressbar_theme: Theme | None = None,
332+
progress: ProgressBar | None = None,
206333
):
207334
"""
208335
Manage progress bars displayed during sampling.
@@ -275,11 +402,16 @@ def __init__(
275402

276403
progress_columns, progress_stats = step_method._progressbar_config(chains)
277404

278-
self._progress = self.create_progress_bar(
279-
progress_columns,
280-
progressbar=progressbar,
281-
progressbar_theme=progressbar_theme,
282-
)
405+
if in_marimo_notebook():
406+
self.combined_progress = False
407+
self._progress = MarimoProgressBar()
408+
else:
409+
self._progress = progress or create_rich_progress_bar(
410+
self.full_stats,
411+
progress_columns,
412+
progressbar=progressbar,
413+
progressbar_theme=progressbar_theme,
414+
)
283415
self.progress_stats = progress_stats
284416
self.update_stats_functions = step_method._make_progressbar_update_functions()
285417

@@ -331,18 +463,6 @@ def _initialize_tasks(self):
331463
for chain_idx in range(self.chains)
332464
]
333465

334-
@staticmethod
335-
def compute_draw_speed(elapsed, draws):
336-
speed = draws / max(elapsed, 1e-6)
337-
338-
if speed > 1 or speed == 0:
339-
unit = "draws/s"
340-
else:
341-
unit = "s/draws"
342-
speed = 1 / speed
343-
344-
return speed, unit
345-
346466
def update(self, chain_idx, is_last, draw, tuning, stats):
347467
if not self._show_progress:
348468
return
@@ -353,7 +473,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
353473
chain_idx = 0
354474

355475
elapsed = self._progress.tasks[chain_idx].elapsed
356-
speed, unit = self.compute_draw_speed(elapsed, draw)
476+
speed, unit = compute_draw_speed(elapsed, draw)
357477

358478
failing = False
359479
all_step_stats = {}
@@ -395,31 +515,3 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
395515
**all_step_stats,
396516
refresh=True,
397517
)
398-
399-
def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
400-
columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]
401-
402-
if self.full_stats:
403-
columns += step_columns
404-
405-
columns += [
406-
TextColumn(
407-
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
408-
table_column=Column("Sampling Speed", ratio=1),
409-
),
410-
TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)),
411-
TimeRemainingColumn(table_column=Column("Remaining", ratio=1)),
412-
]
413-
414-
return CustomProgress(
415-
RecolorOnFailureBarColumn(
416-
table_column=Column("Progress", ratio=2),
417-
failing_color="tab:red",
418-
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
419-
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue
420-
),
421-
*columns,
422-
console=Console(theme=progressbar_theme),
423-
disable=not progressbar,
424-
include_headers=True,
425-
)

0 commit comments

Comments
 (0)