Describe the bug
Gradient checkpointing is gated on `self.training`, so the `log_validation` calls in the training scripts also go through the checkpointing codepath unnecessarily.
I found that checkpointing still runs even under `torch.no_grad()`, so I have to disable it explicitly during validation; the official examples do not do this either.
Disabling it gives a substantial performance boost, bringing validation speed much closer to a normal inference script run outside of a training loop.
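For reference, a minimal sketch of the workaround I use around validation. It assumes a diffusers `ModelMixin`-based model (which exposes `enable_gradient_checkpointing()` / `disable_gradient_checkpointing()`); `unet` and `log_validation_fn` are placeholder names following the example training scripts, not actual library code:

```python
import torch

# Sketch (under the assumptions above): take the model out of train mode before
# validation. Since the checkpointing branch is gated on self.training, eval()
# alone skips it; disable_gradient_checkpointing() additionally clears the flag.
def run_validation(unet, log_validation_fn, *validation_args):
    was_training = unet.training
    unet.eval()
    unet.disable_gradient_checkpointing()
    try:
        with torch.no_grad():
            log_validation_fn(*validation_args)
    finally:
        # Restore the training-time configuration afterwards.
        unet.enable_gradient_checkpointing()
        if was_training:
            unet.train()
```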
Reproduction
Add print statements to the checkpointing function and observe that they also fire during `log_validation`.
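A self-contained illustration of the gating (a hypothetical module, not the actual diffusers code) showing why the codepath is hit even under `torch.no_grad()` when the model is left in train mode:

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = True
        self.inner = torch.nn.Linear(8, 8)

    def forward(self, x):
        # Gated on self.training, not on whether gradients are enabled.
        if self.training and self.gradient_checkpointing:
            print("checkpointing codepath taken")
            return checkpoint(self.inner, x, use_reentrant=False)
        return self.inner(x)

block = Block()
with torch.no_grad():
    block(torch.randn(2, 8))  # prints, because block.training is still True
```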
Logs
No response
System Info
Who can help?