diff --git a/detectron2/config/defaults.py b/detectron2/config/defaults.py index 506651730e..7af828b4d4 100644 --- a/detectron2/config/defaults.py +++ b/detectron2/config/defaults.py @@ -654,3 +654,7 @@ # Do not commit any configs into it. _C.GLOBAL = CN() _C.GLOBAL.HACK = 1.0 + +# Period (measured in iterations) for writing logs during training +_C.TRAINER = CN() +_C.TRAINER.LOG_PERIOD = 20 diff --git a/detectron2/engine/defaults.py b/detectron2/engine/defaults.py index 3dbcd86b75..84587534c2 100644 --- a/detectron2/engine/defaults.py +++ b/detectron2/engine/defaults.py @@ -496,7 +496,7 @@ def test_and_save_results(): if comm.is_main_process(): # Here the default print/log frequency of each writer is used. # run writers in the end, so that evaluation metrics are written - ret.append(hooks.PeriodicWriter(self.build_writers(), period=20)) + ret.append(hooks.PeriodicWriter(self.build_writers(), period=cfg.TRAINER.LOG_PERIOD)) return ret def build_writers(self): diff --git a/tests/test_trainer.py b/tests/test_trainer.py new file mode 100644 index 0000000000..4f2e0229b2 --- /dev/null +++ b/tests/test_trainer.py @@ -0,0 +1,39 @@ +import unittest +from detectron2.config import get_cfg +from detectron2.engine import DefaultTrainer +from detectron2.data import DatasetCatalog, MetadataCatalog + +class TestTrainer(unittest.TestCase): + def test_log_period(self): + cfg = get_cfg() + cfg.TRAINER.LOG_PERIOD = 1 + + # Add minimum required config for trainer initialization + cfg.DATASETS.TRAIN = ("dummy_dataset",) + cfg.MODEL.DEVICE = "cpu" + cfg.OUTPUT_DIR = "." + + + + def dummy_dataset(): + return [{ + "file_name": "dummy.jpg", + "height": 500, + "width": 500, + "image_id": 1, + "annotations": [{ + "bbox": [0, 0, 10, 10], + "bbox_mode": 0, # XYXY_ABS + "category_id": 0, + "iscrowd": 0, + }] + }] + + # Register a dummy dataset with one sample + DatasetCatalog.register("dummy_dataset", dummy_dataset) + MetadataCatalog.get("dummy_dataset").set(thing_classes=["dummy"]) + + trainer = DefaultTrainer(cfg) + hooks = trainer.build_hooks() + writer_hook = hooks[-1] # PeriodicWriter should be last hook + self.assertEqual(writer_hook._period, 1, "Log period from config not properly set")