I encountered a situation where memory usage increases linearly during training, although all tensor variables are overwritten or even explicitly deleted regularly. This occurs when passing an input batch through some layers (e.g. an MLP) and then replacing part of the input batch with the output of the forward pass. Can someone help me figure out why this is happening, and how I can achieve the same result without the memory growing throughout training? Thanks in advance for any hints!

#!/usr/bin/env python3
import torch
from torch import nn
import torch_geometric
import tracemalloc


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_layers = torch_geometric.nn.MLP([300, 32, 7])

    def forward(self, batch):
        # index of the second node of every graph in the batch
        batch_ptr = 1 + batch.ptr[:-1]
        batch_ptr = batch_ptr.tolist()
        new_features = self.linear_layers(batch.x)
        # overwrite the first 7 features of the selected nodes in place
        batch.x[batch_ptr, :7] = new_features[batch_ptr, :]
        return batch


def main():
    data_list = list()
    batch_size = 100
    num_features = 300
    for i in range(batch_size):
        data_list.append(
            torch_geometric.data.data.Data(x=torch.randn(12, num_features))
        )
    batch = torch_geometric.data.batch.Batch.from_data_list(data_list)
    mdl = Model()
    tracemalloc.start()
    for i in range(5000):
        # the batch is detached before every forward pass, yet memory still grows
        batch_out = mdl.forward(batch.detach())
        if i % 1000 == 0:
            print(f"Memory ({i}): {tracemalloc.get_traced_memory()}")
        del batch_out
    tracemalloc.stop()


if __name__ == "__main__":
    main()
Replies: 1 comment 5 replies
That's because you are modifying batch in-place, so the computation graph is never freed. Note that although you detach() the batch object, the old features are still kept in memory for backpropagation. This resolves the issue:
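A minimal sketch of a forward pass along those lines, i.e. writing into a fresh tensor instead of mutating batch.x in place, so the graph built in one iteration can be freed as soon as batch_out is deleted (this is a reconstruction of the idea, not necessarily the exact snippet from the reply):

def forward(self, batch):
    batch_ptr = 1 + batch.ptr[:-1]
    batch_ptr = batch_ptr.tolist()
    new_features = self.linear_layers(batch.x)
    # Build a fresh feature tensor instead of writing into batch.x in place,
    # so nothing from a previous iteration stays attached to the graph.
    x = batch.x.clone()
    x[batch_ptr, :7] = new_features[batch_ptr, :]
    batch.x = x
    return batch

If gradients are not needed in this loop at all, running the forward pass under torch.no_grad() also keeps memory flat, since no graph is recorded in the first place.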