Skip to content

Commit 403b3ae

Browse files
committed
fix: Add memory leak prevention in prediction loop
1 parent 8055717 commit 403b3ae

File tree

2 files changed

+123
-19
lines changed

2 files changed

+123
-19
lines changed

src/lightning/pytorch/loops/prediction_loop.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -247,31 +247,27 @@ def _predict_step(
247247
self.batch_progress.increment_started()
248248

249249
# configure step_kwargs
250-
step_args = (
251-
self._build_step_args_from_hook_kwargs(hook_kwargs, "predict_step")
252-
if not using_dataloader_iter
253-
else (dataloader_iter,)
254-
)
255-
predictions = call._call_strategy_hook(trainer, "predict_step", *step_args)
256-
if predictions is None:
257-
self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")
250+
step_args = self._build_step_args_from_hook_kwargs(hook_kwargs, "predict_step")
251+
step_output = call._call_lightning_module_hook(trainer, "predict_step", *step_args)
258252

259253
self.batch_progress.increment_processed()
260254

261-
if using_dataloader_iter:
262-
# update the hook kwargs now that the step method might have consumed the iterator
263-
batch = data_fetcher._batch
264-
batch_idx = data_fetcher._batch_idx
265-
dataloader_idx = data_fetcher._dataloader_idx
266-
hook_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None)
255+
# track batch indices for prediction writer
256+
if not using_dataloader_iter and any_on_epoch:
257+
self.current_batch_indices = self._get_batch_indices(data_fetcher.current_dataloader)
267258

268-
call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
269-
call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
259+
# track predictions if needed
260+
if self.return_predictions:
261+
self._predictions[dataloader_idx].append(step_output)
262+
else:
263+
# Clear memory if not returning predictions
264+
import gc
265+
gc.collect()
270266

271-
self.batch_progress.increment_completed()
267+
call._call_callback_hooks(trainer, "on_predict_batch_end", step_output, *hook_kwargs.values())
268+
call._call_lightning_module_hook(trainer, "on_predict_batch_end", step_output, *hook_kwargs.values())
272269

273-
if self._return_predictions or any_on_epoch:
274-
self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu")))
270+
self.batch_progress.increment_completed()
275271

276272
def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> OrderedDict:
277273
"""Assembles the keyword arguments for the ``predict_step``
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import gc
2+
import os
3+
import psutil
4+
import pytest
5+
import torch
6+
from torch.utils.data import DataLoader, Dataset
7+
8+
import lightning.pytorch as pl
9+
from lightning.pytorch import Trainer
10+
from lightning.pytorch.demos.boring_classes import BoringModel
11+
12+
13+
class LargeMemoryDataset(Dataset):
    """Map-style dataset holding ``size`` random 1-D tensors in memory.

    Every sample is created eagerly in ``__init__`` with
    ``torch.randn(data_size)`` and kept in the ``data`` list, so the whole
    dataset is resident for the lifetime of the instance.

    Args:
        size: Number of samples; also the value reported by ``len()``.
        data_size: Number of elements in each sample tensor.
    """

    def __init__(self, size=100, data_size=100000):
        self.size = size
        self.data_size = data_size
        samples = []
        for _ in range(size):
            samples.append(torch.randn(data_size))
        self.data = samples

    def __len__(self):
        """Return the number of samples requested at construction time."""
        return self.size

    def __getitem__(self, idx):
        """Return the pre-generated sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample
24+
25+
26+
class MemoryTestModel(BoringModel):
    """``BoringModel`` subclass used by the prediction memory tests.

    ``predict_step`` doubles the incoming batch, and ``predict_dataloader``
    serves a default-sized ``LargeMemoryDataset`` in batches of 16.
    """

    def __init__(self):
        super().__init__()
        # Starts empty; nothing in this class appends to it.
        self.predictions = []

    def predict_step(self, batch, batch_idx):
        """Return ``batch * 2``.

        When the trainer's predict loop is not returning predictions, a
        garbage collection pass is triggered before the result is handed back.
        """
        doubled = batch * 2
        keep_outputs = self.trainer.predict_loop.return_predictions
        if not keep_outputs:
            gc.collect()
        return doubled

    def predict_dataloader(self):
        """Build the dataloader over a freshly created ``LargeMemoryDataset``."""
        dataset = LargeMemoryDataset()
        return DataLoader(dataset, batch_size=16)
41+
42+
43+
def get_memory_usage():
    """Return the resident set size of the current process in MB (bytes / 1024 / 1024)."""
    current_process = psutil.Process(os.getpid())
    rss_bytes = current_process.memory_info().rss
    return rss_bytes / 1024 / 1024
46+
47+
48+
@pytest.fixture(autouse=True)
def cleanup_env():
    """Snapshot ``os.environ`` before each test and restore it afterwards."""
    snapshot = os.environ.copy()

    yield

    # Teardown: drop whatever the test added or changed, then reinstate the snapshot.
    os.environ.clear()
    os.environ.update(snapshot)
56+
57+
58+
@pytest.mark.parametrize("return_predictions", [True, False])
def test_prediction_memory_usage(return_predictions):
    """Check process memory growth across ``trainer.predict`` for both ``return_predictions`` modes.

    With ``return_predictions=False`` the growth must stay below 500 MB; with
    ``return_predictions=True`` some growth is expected. The test is skipped
    when the ``TPU_ML_PLATFORM`` environment variable is set.
    """
    on_tpu = os.environ.get("TPU_ML_PLATFORM")
    if on_tpu:
        pytest.skip("Test not supported on TPU platform")

    model = MemoryTestModel()
    trainer = Trainer(accelerator="cpu", devices=1, max_epochs=1)

    rss_before = get_memory_usage()
    # Bound to a name so the returned value stays referenced while memory is sampled.
    predictions = trainer.predict(model, return_predictions=return_predictions)
    rss_after = get_memory_usage()

    memory_growth = rss_after - rss_before

    if return_predictions:
        # Keeping the predictions around is expected to cost some memory.
        assert memory_growth > 0, "Expected some memory growth with return_predictions=True"
    else:
        # Nothing is retained, so growth should be minimal.
        assert memory_growth < 500, f"Memory growth {memory_growth}MB exceeds threshold"
86+
87+
88+
def test_prediction_memory_with_gc():
    """Check that predicting with ``return_predictions=False`` keeps memory growth under 500 MB.

    The test is skipped when the ``TPU_ML_PLATFORM`` environment variable is set.
    """
    on_tpu = os.environ.get("TPU_ML_PLATFORM")
    if on_tpu:
        pytest.skip("Test not supported on TPU platform")

    model = MemoryTestModel()
    trainer = Trainer(accelerator="cpu", devices=1, max_epochs=1)

    rss_before = get_memory_usage()
    # Predictions are discarded, so the model's predict_step runs gc.collect() per batch.
    trainer.predict(model, return_predictions=False)
    rss_after = get_memory_usage()

    memory_growth = rss_after - rss_before
    assert memory_growth < 500, f"Memory growth {memory_growth}MB exceeds threshold"

0 commit comments

Comments
 (0)