Skip to content

Commit 81b9cad

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
allow CommonMetricPrinter without max_iter
Summary: It will then not print ETA Reviewed By: theschnitz Differential Revision: D26205860 fbshipit-source-id: 7f9f26484e349461f85a91a0212d36af546dcb52
1 parent 45a8bfb commit 81b9cad

File tree

2 files changed

+46
-19
lines changed

2 files changed

+46
-19
lines changed

detectron2/utils/events.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
from collections import defaultdict
88
from contextlib import contextmanager
9+
from typing import Optional
910
from fvcore.common.history_buffer import HistoryBuffer
1011

1112
from detectron2.utils.file_io import PathManager
@@ -186,15 +187,35 @@ class CommonMetricPrinter(EventWriter):
186187
To print something in more customized ways, please implement a similar printer by yourself.
187188
"""
188189

189-
def __init__(self, max_iter):
190+
def __init__(self, max_iter: Optional[int] = None):
190191
"""
191192
Args:
192-
max_iter (int): the maximum number of iterations to train.
193-
Used to compute ETA.
193+
max_iter: the maximum number of iterations to train.
194+
Used to compute ETA. If not given, ETA will not be printed.
194195
"""
195196
self.logger = logging.getLogger(__name__)
196197
self._max_iter = max_iter
197-
self._last_write = None
198+
self._last_write = None # (step, time) of last call to write(). Used to compute ETA
199+
200+
def _get_eta(self, storage) -> Optional[str]:
201+
if self._max_iter is None:
202+
return ""
203+
iteration = storage.iter
204+
try:
205+
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
206+
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
207+
return str(datetime.timedelta(seconds=int(eta_seconds)))
208+
except KeyError:
209+
# estimate eta on our own - more noisy
210+
eta_string = None
211+
if self._last_write is not None:
212+
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
213+
iteration - self._last_write[0]
214+
)
215+
eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
216+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
217+
self._last_write = (iteration, time.perf_counter())
218+
return eta_string
198219

199220
def write(self):
200221
import torch
@@ -213,29 +234,17 @@ def write(self):
213234
# they may not exist in the first few iterations (due to warmup)
214235
# or when SimpleTrainer is not used
215236
data_time = None
216-
217-
eta_string = None
218237
try:
219238
iter_time = storage.history("time").global_avg()
220-
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
221-
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
222-
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
223239
except KeyError:
224240
iter_time = None
225-
# estimate eta on our own - more noisy
226-
if self._last_write is not None:
227-
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
228-
iteration - self._last_write[0]
229-
)
230-
eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
231-
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
232-
self._last_write = (iteration, time.perf_counter())
233-
234241
try:
235242
lr = "{:.5g}".format(storage.history("lr").latest())
236243
except KeyError:
237244
lr = "N/A"
238245

246+
eta_string = self._get_eta(storage)
247+
239248
if torch.cuda.is_available():
240249
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
241250
else:

tests/test_events.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tempfile
55
import unittest
66

7-
from detectron2.utils.events import EventStorage, JSONWriter
7+
from detectron2.utils.events import CommonMetricPrinter, EventStorage, JSONWriter
88

99

1010
class TestEventWriter(unittest.TestCase):
@@ -44,3 +44,21 @@ def testScalarMismatchedPeriod(self):
4444
self.assertTrue([int(k.get("key2", 0)) for k in data] == [17, 0, 34, 0, 51, 0])
4545
self.assertTrue([int(k.get("key", 0)) for k in data] == [0, 19, 0, 39, 0, 59])
4646
self.assertTrue([int(k["iteration"]) for k in data] == [17, 19, 34, 39, 51, 59])
47+
48+
def testPrintETA(self):
49+
with EventStorage() as s:
50+
p1 = CommonMetricPrinter(10)
51+
p2 = CommonMetricPrinter()
52+
53+
s.put_scalar("time", 1.0)
54+
s.step()
55+
s.put_scalar("time", 1.0)
56+
s.step()
57+
58+
with self.assertLogs("detectron2.utils.events") as logs:
59+
p1.write()
60+
assert "eta" in logs.output[0]
61+
62+
with self.assertLogs("detectron2.utils.events") as logs:
63+
p2.write()
64+
assert "eta" not in logs.output[0]

0 commit comments

Comments
 (0)