1+ import os
2+ import psutil
3+ import pytest
4+ import torch
5+ from torch .utils .data import DataLoader , Dataset
6+
7+ from lightning .pytorch import Trainer
8+ from lightning .pytorch .demos .boring_classes import BoringModel
9+
10+
class CustomModel(BoringModel):
    """BoringModel variant whose single linear layer accepts 1000-feature inputs."""

    def __init__(self):
        super().__init__()
        # Assigned after the parent initializer so this layer replaces any
        # `layer` attribute the parent may have set. The input width of 1000
        # matches the default `dim` of LargeDataset.
        in_features, out_features = 1000, 2
        self.layer = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        output = self.layer(x)
        return output
18+
19+
class LargeDataset(Dataset):
    """In-memory dataset of random feature vectors with random integer targets.

    Items are ``(features, target)`` tuples by default. When prediction mode is
    enabled (see ``set_prediction_mode``), items are the feature tensor alone.

    Args:
        size: Number of samples to generate.
        dim: Number of features per sample.
    """

    # Class-level default so ``__getitem__`` can read the flag directly instead
    # of probing with ``hasattr()``; ``set_prediction_mode`` overrides it per
    # instance.
    prediction_mode = False

    def __init__(self, size=1000, dim=1000):
        self.data = torch.randn(size, dim)
        # Targets are integers drawn from [0, 10).
        self.targets = torch.randint(0, 10, (size,))

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Returns only the input tensor in prediction mode, otherwise an
        ``(input, target)`` tuple.
        """
        if self.prediction_mode:
            return self.data[idx]
        return self.data[idx], self.targets[idx]

    def __iter__(self):
        """Yield every sample in index order, honouring prediction mode."""
        for i in range(len(self)):
            yield self[i]

    def set_prediction_mode(self, mode=True):
        """Enable (truthy ``mode``) or disable input-only item access."""
        self.prediction_mode = mode
43+
44+
def get_memory_usage():
    """Return the current process's resident set size in MB (bytes / 1024 / 1024)."""
    current_pid = os.getpid()
    rss_bytes = psutil.Process(current_pid).memory_info().rss
    return rss_bytes / 1024 / 1024
48+
49+
@pytest.mark.parametrize("return_predictions", [True, False])
def test_prediction_memory_leak(tmp_path, return_predictions):
    """Check ``Trainer.predict`` honours ``return_predictions`` without leaking memory.

    With ``return_predictions=False`` nothing must be returned and the process
    RSS growth across the predict call must stay below 100 MB. With
    ``return_predictions=True`` the per-batch outputs must be returned.
    """
    import gc

    # Build the model and an input-only dataset (predict batches carry no targets).
    model = CustomModel()
    dataset = LargeDataset()
    dataset.set_prediction_mode(True)
    dataloader = DataLoader(dataset, batch_size=32)

    # Construct the trainer before taking the baseline so its own setup cost is
    # not attributed to the prediction run.
    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="cpu",
        devices=1,
        max_epochs=1,
    )

    # Collect garbage before sampling RSS to reduce noise from pending cycles.
    gc.collect()
    initial_memory = get_memory_usage()

    predictions = trainer.predict(model, dataloaders=dataloader, return_predictions=return_predictions)

    gc.collect()
    final_memory = get_memory_usage()
    memory_growth = final_memory - initial_memory

    if not return_predictions:
        # Nothing should be retained or returned, so growth should be minimal.
        assert predictions is None
        assert memory_growth < 100, f"Memory growth {memory_growth} MB is too high when return_predictions=False"
    else:
        # Assert on the returned value rather than on RSS growth: the outputs
        # here are only a few KB, so RSS (page-granular, subject to allocator
        # reuse) may legitimately not grow at all.
        assert predictions is not None, "Expected predictions to be returned"
        assert len(predictions) == len(dataloader), "Expected one prediction per batch"
0 commit comments