Skip to content

Commit ff8caee

Browse files
ppwwyyxx authored and facebook-github-bot committed
add default_writers()
Summary: Allow this logic to be reused when not using DefaultTrainer.
Reviewed By: theschnitz
Differential Revision: D26213772
fbshipit-source-id: 9c85412fb24323018143c9fe9d42e23562dbb2ff
1 parent 4b539e4 commit ff8caee

File tree

4 files changed

+42
-45
lines changed

4 files changed

+42
-45
lines changed

detectron2/engine/defaults.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import sys
1616
from collections import OrderedDict
17+
from typing import Optional
1718
import torch
1819
from fvcore.nn.precise_bn import get_bn_modules
1920
from torch.nn.parallel import DistributedDataParallel
@@ -43,7 +44,13 @@
4344
from . import hooks
4445
from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase
4546

46-
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
47+
__all__ = [
48+
"default_argument_parser",
49+
"default_setup",
50+
"default_writers",
51+
"DefaultPredictor",
52+
"DefaultTrainer",
53+
]
4754

4855

4956
def default_argument_parser(epilog=None):
@@ -157,6 +164,27 @@ def default_setup(cfg, args):
157164
torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
158165

159166

167+
def default_writers(output_dir: str, max_iter: Optional[int] = None):
    """
    Assemble the standard list of :class:`EventWriter` instances.

    The list currently holds, in order, a :class:`CommonMetricPrinter`,
    a :class:`JSONWriter` and a :class:`TensorboardXWriter`.

    Args:
        output_dir: directory to store JSON metrics (``metrics.json``)
            and tensorboard events
        max_iter: the total number of iterations

    Returns:
        list[EventWriter]: a list of :class:`EventWriter` objects.
    """
    # The printer reports "common" metrics only, so it may not always
    # show what you want to see.
    printer = CommonMetricPrinter(max_iter)
    json_writer = JSONWriter(os.path.join(output_dir, "metrics.json"))
    tensorboard_writer = TensorboardXWriter(output_dir)
    return [printer, json_writer, tensorboard_writer]
186+
187+
160188
class DefaultPredictor:
161189
"""
162190
Create a simple end-to-end predictor with the given config that runs on
@@ -377,37 +405,21 @@ def test_and_save_results():
377405
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
378406

379407
if comm.is_main_process():
408+
# Here the default print/log frequency of each writer is used.
380409
# run writers in the end, so that evaluation metrics are written
381410
ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
382411
return ret
383412

384413
def build_writers(self):
385414
"""
386-
Build a list of writers to be used. By default it contains
387-
writers that write metrics to the screen,
388-
a json file, and a tensorboard event file respectively.
415+
Build a list of writers to be used using :func:`default_writers()`.
389416
If you'd like a different list of writers, you can overwrite it in
390417
your trainer.
391418
392419
Returns:
393420
list[EventWriter]: a list of :class:`EventWriter` objects.
394-
395-
It is now implemented by:
396-
::
397-
return [
398-
CommonMetricPrinter(self.max_iter),
399-
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
400-
TensorboardXWriter(self.cfg.OUTPUT_DIR),
401-
]
402-
403421
"""
404-
# Here the default print/log frequency of each writer is used.
405-
return [
406-
# It may not always print what you want to see, since it prints "common" metrics only.
407-
CommonMetricPrinter(self.max_iter),
408-
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
409-
TensorboardXWriter(self.cfg.OUTPUT_DIR),
410-
]
422+
return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
411423

412424
def train(self):
413425
"""

tests/test_engine.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import tempfile
66
import time
77
import unittest
8-
from unittest.mock import MagicMock
98
import torch
109
from torch import nn
1110

@@ -61,12 +60,12 @@ def test_writer_hooks(self):
6160
with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
6261
json_file = os.path.join(d, "metrics.json")
6362
writers = [CommonMetricPrinter(max_iter), JSONWriter(json_file)]
64-
logger_info = writers[0].logger.info = MagicMock()
6563

6664
trainer.register_hooks(
6765
[hooks.EvalHook(0, lambda: {"metric": 100}), hooks.PeriodicWriter(writers)]
6866
)
69-
trainer.train(0, max_iter)
67+
with self.assertLogs(writers[0].logger) as logs:
68+
trainer.train(0, max_iter)
7069

7170
with open(json_file, "r") as f:
7271
data = [json.loads(line.strip()) for line in f]
@@ -75,12 +74,11 @@ def test_writer_hooks(self):
7574
self.assertIn("metric", data[-1], "Eval metric must be in last line of JSON!")
7675

7776
# test logged messages from CommonMetricPrinter
78-
all_logs = [str(x) for x in logger_info.call_args_list]
79-
self.assertEqual(len(all_logs), 3)
80-
for log, iter in zip(all_logs, [19, 39, 49]):
77+
self.assertEqual(len(logs.output), 3)
78+
for log, iter in zip(logs.output, [19, 39, 49]):
8179
self.assertIn(f"iter: {iter}", log)
8280

83-
self.assertIn("eta: 0:00:00", all_logs[-1], "Last ETA must be 0!")
81+
self.assertIn("eta: 0:00:00", logs.output[-1], "Last ETA must be 0!")
8482

8583
@unittest.skipIf(os.environ.get("CI"), "Require COCO data.")
8684
def test_default_trainer(self):

tests/test_events.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def testPrintETA(self):
5757

5858
with self.assertLogs("detectron2.utils.events") as logs:
5959
p1.write()
60-
assert "eta" in logs.output[0]
60+
self.assertIn("eta", logs.output[0])
6161

6262
with self.assertLogs("detectron2.utils.events") as logs:
6363
p2.write()
64-
assert "eta" not in logs.output[0]
64+
self.assertNotIn("eta", logs.output[0])

tools/plain_train_net.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
build_detection_test_loader,
3434
build_detection_train_loader,
3535
)
36-
from detectron2.engine import default_argument_parser, default_setup, launch
36+
from detectron2.engine import default_argument_parser, default_setup, default_writers, launch
3737
from detectron2.evaluation import (
3838
CityscapesInstanceEvaluator,
3939
CityscapesSemSegEvaluator,
@@ -48,12 +48,7 @@
4848
)
4949
from detectron2.modeling import build_model
5050
from detectron2.solver import build_lr_scheduler, build_optimizer
51-
from detectron2.utils.events import (
52-
CommonMetricPrinter,
53-
EventStorage,
54-
JSONWriter,
55-
TensorboardXWriter,
56-
)
51+
from detectron2.utils.events import EventStorage
5752

5853
logger = logging.getLogger("detectron2")
5954

@@ -138,15 +133,7 @@ def do_train(cfg, model, resume=False):
138133
checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
139134
)
140135

141-
writers = (
142-
[
143-
CommonMetricPrinter(max_iter),
144-
JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
145-
TensorboardXWriter(cfg.OUTPUT_DIR),
146-
]
147-
if comm.is_main_process()
148-
else []
149-
)
136+
writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else []
150137

151138
# compared to "train_net.py", we do not support accurate timing and
152139
# precise BN here, because they are not trivial to implement in a small training loop

0 commit comments

Comments
 (0)