Skip to content

Commit d6e8c13

Browse files
Tao Xufacebook-github-bot
authored andcommitted
refactor GANs trainer
Summary: Refactor and simplify the trainer by reusing the _write_metrics defined in SimpleTrainer. Reviewed By: newstzpz Differential Revision: D26038293 fbshipit-source-id: 174c386375759ce761ff2db80cb5a076a083bcc8
1 parent 25e283a commit d6e8c13

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

detectron2/engine/train_loop.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,12 @@ def run_step(self):
244244
"""
245245
self.optimizer.step()
246246

247-
def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float):
247+
def _write_metrics(
248+
self,
249+
loss_dict: Dict[str, torch.Tensor],
250+
data_time: float,
251+
prefix: str = "",
252+
):
248253
"""
249254
Args:
250255
loss_dict (dict): dict of scalar losses
@@ -281,7 +286,7 @@ def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float):
281286
f"loss_dict = {metrics_dict}"
282287
)
283288

284-
storage.put_scalar("total_loss", total_losses_reduced)
289+
storage.put_scalar("{}total_loss".format(prefix), total_losses_reduced)
285290
if len(metrics_dict) > 1:
286291
storage.put_scalars(**metrics_dict)
287292

0 commit comments

Comments
 (0)