Is it possible to make TensorBoardStatsHandler log loss per epoch rather than per iteration? #5339
-
Hi!
-
@yhuang1997 One simple solution is to pass your own custom epoch-level event writer to `TensorBoardStatsHandler`. The code would be like this:
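A minimal sketch of that idea (the writer function `epoch_loss_writer` is a hypothetical name, and the trainer output is assumed to be a dict with a `"loss"` key, as in the snippets further down):

```python
from monai.handlers import TensorBoardStatsHandler

def epoch_loss_writer(engine, writer):
    # at EPOCH_COMPLETED, engine.state.output still holds only the
    # output of the last iteration, so this logs the last batch's loss
    writer.add_scalar("Loss", engine.state.output["loss"], engine.state.epoch)

tensorboard_handler = TensorBoardStatsHandler(
    log_dir="./runs",
    iteration_log=False,                   # disable per-iteration logging
    epoch_event_writer=epoch_loss_writer,  # custom epoch-level writer
)
```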
But notice that the given sample would only log the loss at the moment the epoch completes, not the average loss over the epoch.
-
Hi, @KumoLiu
Thanks again for your solution. You enlightened me a lot, but it amounts to just logging the loss of the last batch in each epoch. So I followed what you have done and tried to use `monai.handlers.MetricLogger` and `monai.handlers.TensorBoardStatsHandler` to log the mean loss of each epoch: `MetricLogger` records the loss at every iteration, and `TensorBoardStatsHandler` writes the epoch average. Concretely, I override `_default_epoch_writer` and add a `bind_handler` method to `TensorBoardStatsHandler` (shown below as a subclass). Maybe it's not robust enough, so please use it carefully, but it does handle the issue I face.
```python
import torch
from ignite.engine import Engine
from monai.handlers import MetricLogger, TensorBoardStatsHandler
from monai.utils import is_scalar


class EpochMeanLossTBStatsHandler(TensorBoardStatsHandler):
    """Logs the mean loss of each epoch, read from a bound MetricLogger."""

    handler = None

    def bind_handler(self, handler):
        self.handler = handler

    def _default_epoch_writer(self, engine: Engine, writer) -> None:
        current_epoch = self.global_epoch_transform(engine.state.epoch)
        ################## new ###############
        if self.handler is not None:
            if isinstance(self.handler, MetricLogger):
                # MetricLogger.loss holds (iteration, loss) pairs for the epoch;
                # column 1 is the loss value, so average it over all iterations
                mean_epoch_loss = torch.tensor(self.handler.loss)[:, 1].mean()
                self._write_scalar(
                    _engine=engine,
                    writer=writer,
                    tag=self.tag_name,
                    value=mean_epoch_loss.item(),
                    step=current_epoch,
                )
                # clear the handler cache so the next epoch starts fresh
                self.handler.loss = []
            else:
                raise NotImplementedError
        ################## new ###############
        # the rest is unchanged from the stock _default_epoch_writer
        summary_dict = engine.state.metrics
        for name, value in summary_dict.items():
            if is_scalar(value):
                self._write_scalar(engine, writer, name, value, current_epoch)
        if self.state_attributes is not None:
            for attr in self.state_attributes:
                self._write_scalar(engine, writer, attr, getattr(engine.state, attr, None), current_epoch)
        writer.flush()
```

And wire the two handlers together:

```python
metric_logger_handler = MetricLogger(loss_transform=lambda x: x["loss"])
tensorboard_state_handler = EpochMeanLossTBStatsHandler(
    log_dir=tensorboard_filename,
    output_transform=lambda x: x["loss"],
    iteration_log=False,  # please set this to disable logging per iteration
)
tensorboard_state_handler.bind_handler(metric_logger_handler)
train_handlers = [..., metric_logger_handler, tensorboard_state_handler]
trainer = Trainer(..., handlers=train_handlers)
```
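As a side note on the `[:, 1]` indexing above: `MetricLogger` appends one `(iteration, loss)` pair per iteration, so column 1 holds the loss values. A standalone check with made-up numbers:

```python
import torch

# MetricLogger.loss collects (iteration, loss_value) pairs over the epoch,
# e.g. after three iterations (made-up values):
loss_pairs = [(1, 0.9), (2, 0.7), (3, 0.5)]

# column 1 selects the loss values; their mean is the per-epoch average
mean_epoch_loss = torch.tensor(loss_pairs)[:, 1].mean()
print(mean_epoch_loss.item())  # ~0.7
```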