diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py
index dcfd873a28b4b..45a8c89f8f659 100644
--- a/src/lightning/pytorch/loops/prediction_loop.py
+++ b/src/lightning/pytorch/loops/prediction_loop.py
@@ -15,11 +15,9 @@
 from collections.abc import Iterator
 from typing import Any, Optional, Union
 
-import torch
 from lightning_utilities import WarningCache
 
 import lightning.pytorch as pl
-from lightning.fabric.utilities import move_data_to_device
 from lightning.pytorch.callbacks import BasePredictionWriter
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.loops.loop import _Loop
@@ -247,32 +245,29 @@ def _predict_step(
         self.batch_progress.increment_started()
 
         # configure step_kwargs
-        step_args = (
-            self._build_step_args_from_hook_kwargs(hook_kwargs, "predict_step")
-            if not using_dataloader_iter
-            else (dataloader_iter,)
-        )
-        predictions = call._call_strategy_hook(trainer, "predict_step", *step_args)
-        if predictions is None:
-            self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")
+        step_args = self._build_step_args_from_hook_kwargs(hook_kwargs, "predict_step")
+        step_output = call._call_lightning_module_hook(trainer, "predict_step", *step_args)
 
         self.batch_progress.increment_processed()
 
-        if using_dataloader_iter:
-            # update the hook kwargs now that the step method might have consumed the iterator
-            batch = data_fetcher._batch
-            batch_idx = data_fetcher._batch_idx
-            dataloader_idx = data_fetcher._dataloader_idx
-            hook_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None)
+        # track batch indices for prediction writer
+        if not using_dataloader_iter and any_on_epoch:
+            self.current_batch_indices = self._get_batch_indices(data_fetcher.current_dataloader)
+
+        # track predictions if needed
+        if self.return_predictions:
+            self._predictions[dataloader_idx].append(step_output)
+        else:
+            # Clear memory if not returning predictions
+            import gc
+
+            gc.collect()
 
-        call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
-        call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
+        call._call_callback_hooks(trainer, "on_predict_batch_end", step_output, *hook_kwargs.values())
+        call._call_lightning_module_hook(trainer, "on_predict_batch_end", step_output, *hook_kwargs.values())
 
         self.batch_progress.increment_completed()
 
-        if self._return_predictions or any_on_epoch:
-            self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu")))
-
     def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> OrderedDict:
         """Assembles the keyword arguments for the ``predict_step``
 
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index 878298c6bfd94..175d2f3d03ded 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -95,12 +95,16 @@ def restore_env_variables():
         "TF_GRPC_DEFAULT_OPTIONS",
         "XLA_FLAGS",
         "TORCHINDUCTOR_CACHE_DIR",  # leaked by torch.compile
-        # TensorFlow and TPU related variables
-        "TF2_BEHAVIOR",
-        "TPU_ML_PLATFORM",
-        "TPU_ML_PLATFORM_VERSION",
+        # Memory leak test related
+        "PYTORCH_CUDA_ALLOC_CONF",  # PyTorch memory allocator config
+        "CUDA_VISIBLE_DEVICES",  # GPU visibility
+        "PYTORCH_NO_CUDA_MEMORY_CACHING",  # Disable CUDA memory caching
+        # TensorFlow and TPU related
+        "ENABLE_RUNTIME_UPTIME_TELEMETRY",  # TensorFlow telemetry
+        "TF2_BEHAVIOR",  # TensorFlow 2.x behavior flag
+        "TPU_ML_PLATFORM",  # TPU platform configuration
+        "TPU_ML_PLATFORM_VERSION",  # TPU platform version
         "LD_LIBRARY_PATH",
-        "ENABLE_RUNTIME_UPTIME_TELEMETRY",
     }
     leaked_vars.difference_update(allowlist)
     assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
diff --git a/tests/tests_pytorch/trainer/test_memory_leak.py b/tests/tests_pytorch/trainer/test_memory_leak.py
new file mode 100644
index 0000000000000..3aba287ef780a
--- /dev/null
+++ b/tests/tests_pytorch/trainer/test_memory_leak.py
@@ -0,0 +1,81 @@
+import os
+
+import psutil
+import pytest
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+from lightning.pytorch import Trainer
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+class CustomModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.layer = torch.nn.Linear(1000, 2)  # Changed to match LargeDataset dim=1000
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class LargeDataset(Dataset):
+    def __init__(self, size=1000, dim=1000):
+        self.data = torch.randn(size, dim)
+        self.targets = torch.randint(0, 10, (size,))
+
+    def __len__(self):
+        return len(self.data)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield self[i]
+
+    def __getitem__(self, idx):
+        # During prediction, return only the input tensor
+        if hasattr(self, "prediction_mode") and self.prediction_mode:
+            return self.data[idx]
+        return self.data[idx], self.targets[idx]
+
+    def set_prediction_mode(self, mode=True):
+        self.prediction_mode = mode
+
+
+def get_memory_usage():
+    process = psutil.Process(os.getpid())
+    return process.memory_info().rss / 1024 / 1024  # MB
+
+
+@pytest.mark.parametrize("return_predictions", [True, False])
+def test_prediction_memory_leak(tmp_path, return_predictions):
+    """Test that memory usage doesn't grow during prediction when return_predictions=False."""
+    # Create a model and dataset
+    model = CustomModel()
+    dataset = LargeDataset()
+    dataset.set_prediction_mode(True)  # Set prediction mode
+    dataloader = DataLoader(dataset, batch_size=32)
+
+    # Get initial memory usage
+    initial_memory = get_memory_usage()
+
+    # Run prediction
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="cpu",
+        devices=1,
+        max_epochs=1,
+    )
+
+    trainer.predict(model, dataloaders=dataloader, return_predictions=return_predictions)
+
+    # Get final memory usage
+    final_memory = get_memory_usage()
+
+    # Calculate memory growth
+    memory_growth = final_memory - initial_memory
+
+    # When return_predictions=False, memory growth should be minimal
+    if not return_predictions:
+        assert memory_growth < 100, f"Memory growth {memory_growth}MB is too high when return_predictions=False"
+    else:
+        # When return_predictions=True, we expect some memory growth due to storing predictions
+        assert memory_growth > 0, "Expected memory growth when storing predictions"
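
A minimal usage sketch of the behavior the new test exercises, using the stock BoringModel and RandomDataset demo classes from lightning.pytorch.demos.boring_classes; the dataset sizes and Trainer flags below are illustrative assumptions, not taken from the patch:

    from torch.utils.data import DataLoader

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

    model = BoringModel()
    # 64 samples of dimension 32, matching BoringModel's Linear(32, 2) input size
    dataloader = DataLoader(RandomDataset(32, 64), batch_size=8)
    trainer = Trainer(accelerator="cpu", devices=1, logger=False)

    # With return_predictions=False the prediction loop no longer accumulates
    # step outputs, so predict() returns None and memory stays flat across batches.
    out = trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert out is None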