Skip to content

Commit dc653e4

Browse files
committed
add setting for tqdm mininterval
1 parent ccad76e commit dc653e4

File tree

3 files changed

+21
-12
lines changed

3 files changed

+21
-12
lines changed

bayesflow/default_settings.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,8 @@ def __init__(self, meta_dict: dict, mandatory_fields: list = []):
199199

200200

201201
MMD_BANDWIDTH_LIST = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100, 1e3, 1e4, 1e5, 1e6]
202+
203+
# Minimum time interval between tqdm status updates to reduce
204+
# load. Only respected when refresh=False in set_postfix
205+
# and set_postfix_str
206+
TQDM_MININTERVAL = 0.1 # in seconds

bayesflow/simulation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
logging.basicConfig()
3131

32-
from bayesflow.default_settings import DEFAULT_KEYS
32+
from bayesflow.default_settings import DEFAULT_KEYS, TQDM_MININTERVAL
3333
from bayesflow.diagnostics import plot_prior2d
3434
from bayesflow.exceptions import ConfigurationError
3535

@@ -1109,7 +1109,9 @@ def presimulate_and_save(
11091109
# Generate the presimulation files
11101110
file_counter = extend_from
11111111
for i in range(total_files):
1112-
with tqdm(total=batches_per_file, desc=f"Batches generated for file {i+1}") as p_bar:
1112+
with tqdm(
1113+
total=batches_per_file, desc=f"Batches generated for file {i+1}", mininterval=TQDM_MININTERVAL
1114+
) as p_bar:
11131115
file_list = [{} for _ in range(batches_per_file)]
11141116
for k in range(batches_per_file):
11151117
file_list[k] = self.__call__(batch_size=batch_size)

bayesflow/trainers.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
)
3535
from bayesflow.computational_utilities import maximum_mean_discrepancy
3636
from bayesflow.configuration import *
37-
from bayesflow.default_settings import DEFAULT_KEYS, OPTIMIZER_DEFAULTS
37+
from bayesflow.default_settings import DEFAULT_KEYS, OPTIMIZER_DEFAULTS, TQDM_MININTERVAL
3838
from bayesflow.diagnostics import plot_latent_space_2d, plot_sbc_histograms
3939
from bayesflow.exceptions import ArgumentError, SimulationError
4040
from bayesflow.helper_classes import (
@@ -432,7 +432,7 @@ def train_online(
432432

433433
# Loop through training epochs
434434
for ep in range(1, epochs + 1):
435-
with tqdm(total=iterations_per_epoch, desc=f"Training epoch {ep}") as p_bar:
435+
with tqdm(total=iterations_per_epoch, desc=f"Training epoch {ep}", mininterval=TQDM_MININTERVAL) as p_bar:
436436
for it in range(1, iterations_per_epoch + 1):
437437
# Perform one training step and obtain current loss value
438438
loss = self._train_step(batch_size, update_step=_backprop_step, **kwargs)
@@ -450,7 +450,7 @@ def train_online(
450450
disp_str = format_loss_string(ep, it, loss, avg_dict, lr=lr)
451451

452452
# Update progress bar
453-
p_bar.set_postfix_str(disp_str)
453+
p_bar.set_postfix_str(disp_str, refresh=False)
454454
p_bar.update(1)
455455

456456
# Store and compute validation loss, if specified
@@ -558,7 +558,9 @@ def train_offline(
558558

559559
# Loop through epochs
560560
for ep in range(1, epochs + 1):
561-
with tqdm(total=data_set.num_batches, desc="Training epoch {}".format(ep)) as p_bar:
561+
with tqdm(
562+
total=data_set.num_batches, desc="Training epoch {}".format(ep), mininterval=TQDM_MININTERVAL
563+
) as p_bar:
562564
# Loop through dataset
563565
for bi, forward_dict in enumerate(data_set, start=1):
564566
# Perform one training step and obtain current loss value
@@ -578,7 +580,7 @@ def train_offline(
578580
disp_str = format_loss_string(ep, bi, loss, avg_dict, lr=lr, it_str="Batch")
579581

580582
# Update progress
581-
p_bar.set_postfix_str(disp_str)
583+
p_bar.set_postfix_str(disp_str, refresh=False)
582584
p_bar.update(1)
583585

584586
# Store and compute validation loss, if specified
@@ -731,7 +733,7 @@ def train_from_presimulation(
731733
f"Loading a simulation file resulted in a {type(epoch_data)}. Must be a dictionary or a list!"
732734
)
733735

734-
with tqdm(total=len(index_list), desc=f"Training epoch {ep}") as p_bar:
736+
with tqdm(total=len(index_list), desc=f"Training epoch {ep}", mininterval=TQDM_MININTERVAL) as p_bar:
735737
for it, index in enumerate(index_list, start=1):
736738
# Perform one training step and obtain current loss value
737739
input_dict = self.configurator(epoch_data[index])
@@ -756,7 +758,7 @@ def train_from_presimulation(
756758
disp_str = format_loss_string(ep, it, loss, avg_dict, lr=lr)
757759

758760
# Update progress bar
759-
p_bar.set_postfix_str(disp_str)
761+
p_bar.set_postfix_str(disp_str, refresh=False)
760762
p_bar.update(1)
761763

762764
# Store after each epoch, if specified
@@ -873,7 +875,7 @@ def train_experience_replay(
873875

874876
# Loop through epochs
875877
for ep in range(1, epochs + 1):
876-
with tqdm(total=iterations_per_epoch, desc=f"Training epoch {ep}") as p_bar:
878+
with tqdm(total=iterations_per_epoch, desc=f"Training epoch {ep}", mininterval=TQDM_MININTERVAL) as p_bar:
877879
for it in range(1, iterations_per_epoch + 1):
878880
# Simulate a batch of data and store into buffer
879881
input_dict = self._forward_inference(
@@ -900,7 +902,7 @@ def train_experience_replay(
900902
disp_str = format_loss_string(ep, it, loss, avg_dict, lr=lr)
901903

902904
# Update progress bar
903-
p_bar.set_postfix_str(disp_str)
905+
p_bar.set_postfix_str(disp_str, refresh=False)
904906
p_bar.update(1)
905907

906908
# Store and compute validation loss, if specified
@@ -1099,7 +1101,7 @@ def mmd_hypothesis_test(
10991101
num_reference = reference_summary.shape[0]
11001102

11011103
mmd_null_samples = np.empty(num_null_samples, dtype=np.float32)
1102-
for i in tqdm(range(num_null_samples)):
1104+
for i in tqdm(range(num_null_samples), mininterval=TQDM_MININTERVAL):
11031105
if bootstrap:
11041106
bootstrap_idx = np.random.randint(0, num_reference, size=num_observed)
11051107
simulated_summary = tf.gather(reference_summary, bootstrap_idx, axis=0)

0 commit comments

Comments
 (0)