@@ -6,6 +6,7 @@
 import time
 from collections import defaultdict
 from contextlib import contextmanager
+from typing import Optional
 from fvcore.common.history_buffer import HistoryBuffer
 
 from detectron2.utils.file_io import PathManager
@@ -186,15 +187,35 @@ class CommonMetricPrinter(EventWriter):
     To print something in more customized ways, please implement a similar printer by yourself.
     """
 
-    def __init__(self, max_iter):
+    def __init__(self, max_iter: Optional[int] = None):
         """
         Args:
-            max_iter (int): the maximum number of iterations to train.
-                Used to compute ETA.
+            max_iter: the maximum number of iterations to train.
+                Used to compute ETA. If not given, ETA will not be printed.
         """
         self.logger = logging.getLogger(__name__)
         self._max_iter = max_iter
-        self._last_write = None
+        self._last_write = None  # (step, time) of last call to write(). Used to compute ETA
+
+    def _get_eta(self, storage) -> Optional[str]:
+        if self._max_iter is None:
+            return ""
+        iteration = storage.iter
+        try:
+            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
+            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
+            return str(datetime.timedelta(seconds=int(eta_seconds)))
+        except KeyError:
+            # estimate eta on our own - more noisy
+            eta_string = None
+            if self._last_write is not None:
+                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
+                    iteration - self._last_write[0]
+                )
+                eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+            self._last_write = (iteration, time.perf_counter())
+            return eta_string
 
     def write(self):
         import torch
@@ -213,29 +234,17 @@ def write(self):
             # they may not exist in the first few iterations (due to warmup)
             # or when SimpleTrainer is not used
             data_time = None
-
-        eta_string = None
         try:
             iter_time = storage.history("time").global_avg()
-            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
-            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
-            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
         except KeyError:
             iter_time = None
-            # estimate eta on our own - more noisy
-            if self._last_write is not None:
-                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
-                    iteration - self._last_write[0]
-                )
-                eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
-                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
-            self._last_write = (iteration, time.perf_counter())
-
         try:
             lr = "{:.5g}".format(storage.history("lr").latest())
         except KeyError:
             lr = "N/A"
 
+        eta_string = self._get_eta(storage)
+
         if torch.cuda.is_available():
             max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
         else:
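A quick usage sketch of the new behavior (illustrative values; assumes the detectron2 API as of this commit, where `EventStorage` is a context manager and `CommonMetricPrinter.write()` reads the current storage):

```python
from detectron2.utils.events import CommonMetricPrinter, EventStorage

# With max_iter set, _get_eta() derives the ETA from the median of the
# "time" scalar, falling back to wall-clock deltas between write() calls.
printer = CommonMetricPrinter(max_iter=90000)  # 90000 is an illustrative value

# New in this commit: max_iter may be omitted, in which case _get_eta()
# returns "" and the log line carries no ETA field.
printer_no_eta = CommonMetricPrinter()

with EventStorage(start_iter=0) as storage:
    storage.put_scalar("lr", 0.02, smoothing_hint=False)  # illustrative scalar
    printer_no_eta.write()  # logs iter / lr / memory, with no ETA field
```

Note the two sentinel returns in `_get_eta`: `""` means ETA was deliberately disabled (`max_iter is None`), while `None` means no estimate is available yet (no "time" history and no previous `write()` call to difference against).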