import gc
import os

import psutil
import pytest
import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class LargeMemoryDataset(Dataset):
    def __init__(self, size=100, data_size=100000):
        self.size = size
        self.data_size = data_size
        self.data = [torch.randn(data_size) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx]


class MemoryTestModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.predictions = []

    def predict_step(self, batch, batch_idx):
        # Simulate large memory usage
        result = batch * 2
        if not self.trainer.predict_loop.return_predictions:
            # Clear memory if not returning predictions
            gc.collect()
        return result

    def predict_dataloader(self):
        return DataLoader(LargeMemoryDataset(), batch_size=16)


def get_memory_usage():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024  # Convert to MB


@pytest.fixture(autouse=True)
def cleanup_env():
    """Clean up environment variables after each test."""
    env_backup = os.environ.copy()
    yield
    # Restore the original environment variables
    os.environ.clear()
    os.environ.update(env_backup)


@pytest.mark.parametrize("return_predictions", [True, False])
def test_prediction_memory_usage(return_predictions):
    """Test that memory usage doesn't grow unbounded during prediction."""
    # Skip if running on TPU
    if os.environ.get("TPU_ML_PLATFORM"):
        pytest.skip("Test not supported on TPU platform")

    model = MemoryTestModel()
    trainer = Trainer(accelerator="cpu", devices=1, max_epochs=1)

    # Get initial memory usage
    initial_memory = get_memory_usage()

    # Run prediction
    predictions = trainer.predict(model, return_predictions=return_predictions)

    # Get final memory usage
    final_memory = get_memory_usage()

    # Calculate memory growth
    memory_growth = final_memory - initial_memory

    # If return_predictions is False, memory growth should be minimal
    if not return_predictions:
        assert memory_growth < 500, f"Memory growth {memory_growth} MB exceeds threshold"
    else:
        # With return_predictions=True, some memory growth is expected
        assert memory_growth > 0, "Expected some memory growth with return_predictions=True"


def test_prediction_memory_with_gc():
    """Test that memory usage stays constant when using gc.collect()."""
    # Skip if running on TPU
    if os.environ.get("TPU_ML_PLATFORM"):
        pytest.skip("Test not supported on TPU platform")

    model = MemoryTestModel()
    trainer = Trainer(accelerator="cpu", devices=1, max_epochs=1)

    # Get initial memory usage
    initial_memory = get_memory_usage()

    # Run prediction with gc.collect()
    trainer.predict(model, return_predictions=False)

    # Get final memory usage
    final_memory = get_memory_usage()

    # Memory growth should be minimal
    memory_growth = final_memory - initial_memory
    assert memory_growth < 500, f"Memory growth {memory_growth} MB exceeds threshold"