Logging RL results and tracking them with ModelCheckpoint(monitor=...) #4584

@TrentBrick

I am using PyTorch Lightning in an RL setting and want to save the model whenever it reaches a new maximum average reward. I am using the TensorBoard logger, and I return my neural network loss from training_step() like this:

logs = {"policy_loss": pred_loss}
return {'loss':pred_loss, 'log':logs}

I then log my RL environment rewards in on_epoch_end():

self.logger.experiment.add_scalar("mean_reward", np.mean(reward_losses), self.global_step)
self.logger.experiment.add_scalars('rollout_stats', {"std_reward":np.std(reward_losses),
                "max_reward":np.max(reward_losses), "min_reward":np.min(reward_losses)}, self.global_step)
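
For reference, the statistics logged above boil down to the following. This is a standard-library sketch rather than my actual code: summarize_rewards is a hypothetical helper name, and statistics.pstdev is used because np.std defaults to the population standard deviation.

```python
import statistics


def summarize_rewards(rewards):
    """Return the per-epoch reward statistics as a plain dict (hypothetical helper)."""
    if not rewards:
        raise ValueError("rewards must contain at least one episode")
    return {
        "mean_reward": statistics.fmean(rewards),
        # Population standard deviation, matching np.std's default (ddof=0).
        "std_reward": statistics.pstdev(rewards),
        "max_reward": max(rewards),
        "min_reward": min(rewards),
    }


# Made-up episode rewards, only to show the shape of the result.
stats = summarize_rewards([1.0, 2.0, 3.0, 4.0])
```

Keeping the values in a dict like this means the same numbers could be passed to whichever logging call ends up being the right one.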

Every 5 epochs I also log a second reward metric, computed from rollouts that take the best (greedy) actions instead of sampling them:

if self.current_epoch % self.hparams['eval_every'] == 0 and self.logger:
    output = self.collect_rollouts(greedy=True, num_episodes=self.hparams['eval_episodes'])
    reward_losses = output[0]
    self.logger.experiment.add_scalar("eval_mean", np.mean(reward_losses), self.global_step)

My question is: how can I set up ModelCheckpoint to monitor eval_mean? It is only written out every 5 epochs, which seems like it would be a problem. I would also settle for monitoring mean_reward, which is written out every epoch. Right now the only metric I can monitor successfully is policy_loss, but policy_loss does not always correspond to higher rewards. Setting monitor to anything else throws an error.
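
To make the behavior I am after concrete, here is a framework-free sketch. It is not Lightning's ModelCheckpoint implementation, and I am not claiming Lightning works this way internally: BestMetricTracker, its update() method, and the history values are all made up for illustration. The point is only that epochs where eval_mean is missing should be skipped rather than raise an error.

```python
class BestMetricTracker:
    """Remember the best value of one monitored metric (hypothetical helper)."""

    def __init__(self, monitor, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.monitor = monitor
        self.mode = mode
        self.best = None
        self.best_epoch = None

    def update(self, epoch, metrics):
        """Return True when this epoch's metrics should trigger a save."""
        value = metrics.get(self.monitor)
        if value is None:
            # Metric was not produced this epoch (e.g. no greedy evaluation ran).
            return False
        improved = (
            self.best is None
            or (self.mode == "max" and value > self.best)
            or (self.mode == "min" and value < self.best)
        )
        if improved:
            self.best = value
            self.best_epoch = epoch
        return improved


# Made-up metrics: eval_mean is absent in epoch 1, as it would be between evaluations.
tracker = BestMetricTracker(monitor="eval_mean", mode="max")
history = [
    {"mean_reward": 1.0, "eval_mean": 2.0},
    {"mean_reward": 1.5},
    {"mean_reward": 0.5, "eval_mean": 1.0},
    {"mean_reward": 2.0, "eval_mean": 3.5},
]
saved_epochs = [
    epoch for epoch, metrics in enumerate(history) if tracker.update(epoch, metrics)
]
```

With these values, a save would be triggered on the first evaluation and again when eval_mean rises to 3.5, while the epoch without an evaluation is ignored.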

I know that self.log() should be used in the new PL version, but rewriting my code to use it still did not solve my issue.

I have spent a lot of time looking through the docs and searching for examples of this, but I have found the logging docs quite sparse, and it was difficult even to get everything to log in the first place.

I am using PyTorch Lightning 1.0.5 and PyTorch 1.7.0.

Thank you for any help/guidance.

Labels: question (Further information is requested)