2121#
2222"""Utility functions for solvers and their training loops."""
2323
24- import logging
2524from collections .abc import Iterable
2625from typing import Dict
2726
3029
3130
3231def _description (stats : Dict [str , float ]):
33- stats_str = [f"{ key } : { value :.3f } " for key , value in stats .items ()]
32+ stats_str = [f"{ key } : { value : .4f } " for key , value in stats .items ()]
3433 return " " .join (stats_str )
3534
3635
@@ -74,9 +73,7 @@ class ProgressBar:
7473 "Log and display values during training."
7574
7675 loader : Iterable
77- logger : logging .Logger = None
78- log_format : str = None
79- log_frequency : int = None
76+ log_format : str
8077
8178 _valid_formats = ["tqdm" , "off" ]
8279
@@ -90,23 +87,13 @@ def __post_init__(self):
9087 raise ValueError (
9188 f"log_format must be one of { self ._valid_formats } , "
9289 f"but got { self .log_formats } " )
93- self ._stats = None
9490
9591 def __iter__ (self ):
9692 self .iterator = self .loader
9793 if self .use_tqdm :
9894 self .iterator = tqdm .tqdm (self .iterator )
9995 for num_batch , batch in enumerate (self .iterator ):
10096 yield num_batch , batch
101- self ._log_message (num_batch , self ._stats )
102- self ._log_message (num_batch , self ._stats )
103-
104- def _log_message (self , num_steps , stats ):
105- if self .logger is None :
106- return
107- if num_steps % self .log_frequency != 0 :
108- return
109- self .logger .info (f"Train: Step { num_steps } { _description (stats )} " )
11097
11198 def set_description (self , stats : Dict [str , float ]):
11299 """Update the progress bar description.
@@ -119,5 +106,3 @@ def set_description(self, stats: Dict[str, float]):
119106 """
120107 if self .use_tqdm :
121108 self .iterator .set_description (_description (stats ))
122-
123- self ._stats = stats
0 commit comments