3434)
3535from bayesflow .computational_utilities import maximum_mean_discrepancy
3636from bayesflow .configuration import *
37- from bayesflow .default_settings import DEFAULT_KEYS , OPTIMIZER_DEFAULTS
37+ from bayesflow .default_settings import DEFAULT_KEYS , OPTIMIZER_DEFAULTS , TQDM_MININTERVAL
3838from bayesflow .diagnostics import plot_latent_space_2d , plot_sbc_histograms
3939from bayesflow .exceptions import ArgumentError , SimulationError
4040from bayesflow .helper_classes import (
@@ -432,7 +432,7 @@ def train_online(
432432
433433 # Loop through training epochs
434434 for ep in range (1 , epochs + 1 ):
435- with tqdm (total = iterations_per_epoch , desc = f"Training epoch { ep } " ) as p_bar :
435+ with tqdm (total = iterations_per_epoch , desc = f"Training epoch { ep } " , mininterval = TQDM_MININTERVAL ) as p_bar :
436436 for it in range (1 , iterations_per_epoch + 1 ):
437437 # Perform one training step and obtain current loss value
438438 loss = self ._train_step (batch_size , update_step = _backprop_step , ** kwargs )
@@ -450,7 +450,7 @@ def train_online(
450450 disp_str = format_loss_string (ep , it , loss , avg_dict , lr = lr )
451451
452452 # Update progress bar
453- p_bar .set_postfix_str (disp_str )
453+ p_bar .set_postfix_str (disp_str , refresh = False )
454454 p_bar .update (1 )
455455
456456 # Store and compute validation loss, if specified
@@ -558,7 +558,9 @@ def train_offline(
558558
559559 # Loop through epochs
560560 for ep in range (1 , epochs + 1 ):
561- with tqdm (total = data_set .num_batches , desc = "Training epoch {}" .format (ep )) as p_bar :
561+ with tqdm (
562+ total = data_set .num_batches , desc = "Training epoch {}" .format (ep ), mininterval = TQDM_MININTERVAL
563+ ) as p_bar :
562564 # Loop through dataset
563565 for bi , forward_dict in enumerate (data_set , start = 1 ):
564566 # Perform one training step and obtain current loss value
@@ -578,7 +580,7 @@ def train_offline(
578580 disp_str = format_loss_string (ep , bi , loss , avg_dict , lr = lr , it_str = "Batch" )
579581
580582 # Update progress
581- p_bar .set_postfix_str (disp_str )
583+ p_bar .set_postfix_str (disp_str , refresh = False )
582584 p_bar .update (1 )
583585
584586 # Store and compute validation loss, if specified
@@ -731,7 +733,7 @@ def train_from_presimulation(
731733 f"Loading a simulation file resulted in a { type (epoch_data )} . Must be a dictionary or a list!"
732734 )
733735
734- with tqdm (total = len (index_list ), desc = f"Training epoch { ep } " ) as p_bar :
736+ with tqdm (total = len (index_list ), desc = f"Training epoch { ep } " , mininterval = TQDM_MININTERVAL ) as p_bar :
735737 for it , index in enumerate (index_list , start = 1 ):
736738 # Perform one training step and obtain current loss value
737739 input_dict = self .configurator (epoch_data [index ])
@@ -756,7 +758,7 @@ def train_from_presimulation(
756758 disp_str = format_loss_string (ep , it , loss , avg_dict , lr = lr )
757759
758760 # Update progress bar
759- p_bar .set_postfix_str (disp_str )
761+ p_bar .set_postfix_str (disp_str , refresh = False )
760762 p_bar .update (1 )
761763
762764 # Store after each epoch, if specified
@@ -873,7 +875,7 @@ def train_experience_replay(
873875
874876 # Loop through epochs
875877 for ep in range (1 , epochs + 1 ):
876- with tqdm (total = iterations_per_epoch , desc = f"Training epoch { ep } " ) as p_bar :
878+ with tqdm (total = iterations_per_epoch , desc = f"Training epoch { ep } " , mininterval = TQDM_MININTERVAL ) as p_bar :
877879 for it in range (1 , iterations_per_epoch + 1 ):
878880 # Simulate a batch of data and store into buffer
879881 input_dict = self ._forward_inference (
@@ -900,7 +902,7 @@ def train_experience_replay(
900902 disp_str = format_loss_string (ep , it , loss , avg_dict , lr = lr )
901903
902904 # Update progress bar
903- p_bar .set_postfix_str (disp_str )
905+ p_bar .set_postfix_str (disp_str , refresh = False )
904906 p_bar .update (1 )
905907
906908 # Store and compute validation loss, if specified
@@ -1099,7 +1101,7 @@ def mmd_hypothesis_test(
10991101 num_reference = reference_summary .shape [0 ]
11001102
11011103 mmd_null_samples = np .empty (num_null_samples , dtype = np .float32 )
1102- for i in tqdm (range (num_null_samples )):
1104+ for i in tqdm (range (num_null_samples ), mininterval = TQDM_MININTERVAL ):
11031105 if bootstrap :
11041106 bootstrap_idx = np .random .randint (0 , num_reference , size = num_observed )
11051107 simulated_summary = tf .gather (reference_summary , bootstrap_idx , axis = 0 )
0 commit comments