Commit ce0897c

memory leak test
1 parent 65d38fe commit ce0897c

1 file changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
import os
import psutil
import pytest
import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class CustomModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1000, 2)  # sized to match LargeDataset dim=1000

    def forward(self, x):
        return self.layer(x)


class LargeDataset(Dataset):
    def __init__(self, size=1000, dim=1000):
        self.data = torch.randn(size, dim)
        self.targets = torch.randint(0, 10, (size,))
        self.prediction_mode = False

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # During prediction, return only the input tensor
        if self.prediction_mode:
            return self.data[idx]
        return self.data[idx], self.targets[idx]

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def set_prediction_mode(self, mode=True):
        self.prediction_mode = mode


def get_memory_usage():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024  # resident set size in MB


@pytest.mark.parametrize("return_predictions", [True, False])
def test_prediction_memory_leak(tmp_path, return_predictions):
    """Test that memory usage doesn't grow during prediction when return_predictions=False."""
    # Create a model and a dataset that returns only inputs during prediction
    model = CustomModel()
    dataset = LargeDataset()
    dataset.set_prediction_mode(True)
    dataloader = DataLoader(dataset, batch_size=32)

    # Get initial memory usage
    initial_memory = get_memory_usage()

    # Run prediction
    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="cpu",
        devices=1,
        max_epochs=1,
    )
    predictions = trainer.predict(model, dataloaders=dataloader, return_predictions=return_predictions)

    # Get final memory usage and compute growth
    final_memory = get_memory_usage()
    memory_growth = final_memory - initial_memory

    if not return_predictions:
        # With return_predictions=False, predictions are not accumulated, so growth should be minimal
        assert memory_growth < 100, f"Memory growth {memory_growth}MB is too high when return_predictions=False"
    else:
        # With return_predictions=True, some growth is expected from storing the predictions
        assert memory_growth > 0, "Expected memory growth when storing predictions"
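
A possible hardening of this check (a sketch, not part of the commit): RSS readings from psutil can fluctuate with allocator and interpreter behavior, so forcing a garbage-collection pass before each measurement tends to make the thresholds less flaky. The helper name stable_memory_usage is hypothetical.

import gc
import os

import psutil


def stable_memory_usage():
    # Hypothetical variant of get_memory_usage(): run the garbage collector
    # first so unreachable prediction batches don't inflate the RSS reading.
    gc.collect()
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024  # MB

With pytest's -k filter, both parametrized cases can be run on their own, e.g.: pytest -k test_prediction_memory_leak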
