Skip to content

Commit e6fbb11

Browse files
committed
fix: progress bar display
- Progress bar percentage is refreshed after completion to display progress of 100% - Add progress bar details of the number of paths complete
1 parent 90e2b7a commit e6fbb11

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from pytensor.tensor import TensorConstant, TensorVariable
6161
from rich.console import Console, Group
6262
from rich.padding import Padding
63+
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
6364
from rich.table import Table
6465
from rich.text import Text
6566

@@ -1521,12 +1522,20 @@ def multipath_pathfinder(
15211522
results = []
15221523
compute_start = time.time()
15231524
try:
1524-
with CustomProgress(
1525+
desc = f"Paths Complete: {{path_idx}}/{num_paths}"
1526+
progress = CustomProgress(
1527+
"[progress.description]{task.description}",
1528+
BarColumn(),
1529+
"[progress.percentage]{task.percentage:>3.0f}%",
1530+
TimeRemainingColumn(),
1531+
TextColumn("/"),
1532+
TimeElapsedColumn(),
15251533
console=Console(theme=default_progress_theme),
15261534
disable=not progressbar,
1527-
) as progress:
1528-
task = progress.add_task("Fitting", total=num_paths)
1529-
for result in generator:
1535+
)
1536+
with progress:
1537+
task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths)
1538+
for path_idx, result in enumerate(generator, start=1):
15301539
try:
15311540
if isinstance(result, Exception):
15321541
raise result
@@ -1552,7 +1561,14 @@ def multipath_pathfinder(
15521561
lbfgs_status=LBFGSStatus.LBFGS_FAILED,
15531562
)
15541563
)
1555-
progress.update(task, advance=1)
1564+
finally:
1565+
# TODO: display LBFGS and Path Status in real time
1566+
progress.update(
1567+
task,
1568+
description=desc.format(path_idx=path_idx),
1569+
completed=path_idx,
1570+
refresh=True,
1571+
)
15561572
except (KeyboardInterrupt, StopIteration) as e:
15571573
# if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
15581574
if isinstance(e, StopIteration):

0 commit comments

Comments
 (0)