From 08674b3b308b878c03396bbb73c3a4c78ddf0c52 Mon Sep 17 00:00:00 2001 From: Julien Cornebise Date: Mon, 8 Jul 2024 18:23:37 +0100 Subject: [PATCH] Log progress via logger when no progress bar Output terminals that do not support cursor movement ANSI control sequences do not support rich ProgressBars, and thus need progressbar=false. However, it is still very worthy to know how the SMC tasks are progressing. Therefore, when progress bars are deactivated, any chance to the state of the chains is logged to the default logger, using the exact same state dictionary as used for the progressbars. The fact we only log when a change happens means we do not flood the log. The use of the logging facility also allows the called to add timestamps in the logging handlers for a sense of elapsing time. --- pymc/smc/sampling.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 03e64f94c..950c6b8a9 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -345,7 +345,8 @@ def _sample_smc_int( while smc.beta < 1: smc.update_beta_and_weights() - progress_dict[task_id] = {"stage": stage, "beta": smc.beta} + # Index by chain because task_id is None if no progressbar is present + progress_dict[chain] = {"stage": stage, "beta": smc.beta, "task_id": task_id} smc.resample() smc.tune() @@ -378,6 +379,7 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): disable=not progressbar, ) as progress: futures = [] # keep track of the jobs + _log = logging.getLogger(__name__) with multiprocessing.Manager() as manager: # this is the key - we share some state between our # main process and our worker functions @@ -391,6 +393,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): for c in range(chains): # iterate over the jobs we need to run # set visible false so we don't have a lot of bars all at once: task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0") + if not progressbar: + _log.info(f"Queueing Chain {c} Stage: 0 Beta: 0") futures.append( executor.submit( _sample_smc_int, @@ -406,17 +410,26 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): # monitor the progress: done = [] remaining = futures + previous = {c: {} for c in range(chains)} while len(remaining) > 0: finished, remaining = wait(remaining, timeout=0.1) done.extend(finished) - for task_id, update_data in _progress.items(): + for chain, update_data in _progress.items(): stage = update_data["stage"] beta = update_data["beta"] + task_id = update_data["task_id"] + # update the progress bar for this task: progress.update( status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id, refresh=True, ) + # use the logger if there is no progress bar and data has changed: + if not progressbar: + # only log if the stage has changed + if previous[chain].get("stage", -1) != stage: + _log.info(f"Chain: {chain} Stage: {stage} Beta: {beta:.3f}") + previous[chain] = {"stage": stage, "beta": beta} return tuple(cloudpickle.loads(r.result()) for r in done)