GPU memory usage increased with the use of pl.metrics #6612
Unanswered
skyshine102
asked this question in
Lightning Trainer API: Trainer, LightningModule, LightningDataModule
-
I was training a simple ResNet18 model with PyTorch Lightning and the pl.metrics classes. I used wandb to track GPU memory usage and found something weird; see the figure below:

[Figure: wandb GPU memory usage over training steps, two curves — one run with pl.metrics and one without]

We can see that (1) GPU memory usage increased after the first epoch in both curves, and (2) with pl.metrics (I used MetricCollection in my implementation), GPU memory usage keeps increasing with more training steps.
I did realize that since v1.2 the metric objects no longer clear their global state between epochs, so I called
self.metric.reset()
to prevent state accumulation, but I still see this memory increase. The code is attached: https://gist.github.com/skyshine102/83643e5499b780433cb0cdd617c4857d
PyTorch Lightning version: 1.2.1
Does anyone know how (1) and (2) can happen?

Replies: 1 comment
-
OK, I made a stupid mistake: my "training_epoch_end" hook was wrong, so problem (2) has been fixed. But (1) remains, and I still don't know why.
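For reference, a minimal sketch of the pattern discussed in this thread: update the MetricCollection in training_step and compute, log, and reset it once in training_epoch_end, so the metric state is cleared between epochs and no autograd graph is retained inside it. The class name, metric choice, and logging calls below are illustrative assumptions, not the code from the linked gist.

```python
# Sketch only (assumptions, not the original gist): update a MetricCollection per step,
# then compute/log/reset it once per epoch in training_epoch_end.
import torch
import pytorch_lightning as pl
from pytorch_lightning.metrics import Accuracy, MetricCollection  # pl.metrics in v1.2.x


class LitResNet(pl.LightningModule):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone                      # e.g. a torchvision resnet18
        self.criterion = torch.nn.CrossEntropyLoss()
        self.train_metrics = MetricCollection({"train_acc": Accuracy()})

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.backbone(x)
        loss = self.criterion(logits, y)
        # update() only accumulates metric state; detach the predictions so no
        # autograd graph is kept alive inside the metric buffers
        self.train_metrics.update(logits.softmax(dim=-1).detach(), y)
        self.log("train_loss", loss)
        return loss

    def training_epoch_end(self, outputs):
        # compute over the accumulated epoch state, log it, then clear the state
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)
```

Updating the metric with detached tensors and calling reset() right after compute() keeps per-epoch metric state from accumulating across steps, which is the growth described in (2); it does not explain the one-time increase after the first epoch in (1).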