diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index db85bcd1adfaf..3ed492e45e93d 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `MemoryFormat` callback to easily change the model memory format ([#17680](https://github.com/Lightning-AI/lightning/pull/17680))
+
+
 - Added `WeightAveraging` callback that wraps the PyTorch `AveragedModel` class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))
 
diff --git a/src/lightning/pytorch/callbacks/__init__.py b/src/lightning/pytorch/callbacks/__init__.py
index d0ffb7b6a990c..fc9e2a8bd4306 100644
--- a/src/lightning/pytorch/callbacks/__init__.py
+++ b/src/lightning/pytorch/callbacks/__init__.py
@@ -21,6 +21,7 @@
 from lightning.pytorch.callbacks.lambda_function import LambdaCallback
 from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
 from lightning.pytorch.callbacks.lr_monitor import LearningRateMonitor
+from lightning.pytorch.callbacks.memory_format import MemoryFormat
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
 from lightning.pytorch.callbacks.model_summary import ModelSummary
 from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheckpoint
@@ -47,6 +48,7 @@
     "LambdaCallback",
     "LearningRateFinder",
     "LearningRateMonitor",
+    "MemoryFormat",
     "ModelCheckpoint",
     "ModelPruning",
     "ModelSummary",
diff --git a/src/lightning/pytorch/callbacks/memory_format.py b/src/lightning/pytorch/callbacks/memory_format.py
new file mode 100644
index 0000000000000..a5cb75fbe3a84
--- /dev/null
+++ b/src/lightning/pytorch/callbacks/memory_format.py
@@ -0,0 +1,80 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""
+MemoryFormat
+============
+
+Changes the model memory format.
+"""
+
+from collections.abc import MutableSequence
+from typing import Any, Optional
+
+import torch
+
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn
+
+
+class MemoryFormat(Callback):
+    """The `MemoryFormat` callback changes the model memory format to `torch.channels_last` before training starts
+    and restores the default `torch.contiguous_format` when training ends.
+
+    See the PyTorch channels-last tutorial: `<https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_.
+
+    Setting the memory format to `channels_last` usually improves GPU utilization for convolutional models.
+
+    Runs on setup, so the memory format is applied before the model is wrapped by DDP.
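+
+    A minimal usage sketch (illustrative; it assumes the default `channels_last` format and that `model` is any
+    `LightningModule`)::
+
+        from lightning.pytorch import Trainer
+        from lightning.pytorch.callbacks import MemoryFormat
+
+        # convert_input=True additionally converts tensors in mutable-sequence batches to the chosen format
+        trainer = Trainer(callbacks=[MemoryFormat(convert_input=True)])
+        trainer.fit(model)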
+
+    """
+
+    def __init__(self, memory_format: torch.memory_format = torch.channels_last, convert_input: bool = False):
+        self.memory_format = memory_format
+        self.convert_input = convert_input
+
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        if self.memory_format in (
+            torch.channels_last,
+            torch.channels_last_3d,
+        ) and not self.has_layer_benefiting_from_channels_last(pl_module):
+            rank_zero_warn(
+                f"model does not have any layers benefiting from {self.memory_format} format", category=RuntimeWarning
+            )
+
+        pl_module.to(memory_format=self.memory_format)
+
+    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        pl_module.to(memory_format=torch.contiguous_format)
+
+    def on_train_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
+    ) -> None:
+        if not self.convert_input:
+            return
+
+        if not isinstance(batch, MutableSequence):
+            rank_zero_warn(
+                f"batch is not a MutableSequence, cannot convert input to {self.memory_format}", category=RuntimeWarning
+            )
+            return
+
+        for i, item in enumerate(batch):
+            if isinstance(item, torch.Tensor):
+                batch[i] = item.to(memory_format=self.memory_format)
+
+    # layers that typically benefit from the channels_last / channels_last_3d memory formats
+    beneficial_layers = (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.Conv2d, torch.nn.Conv3d)
+
+    def has_layer_benefiting_from_channels_last(self, model: torch.nn.Module) -> bool:
+        return any(isinstance(layer, self.beneficial_layers) for layer in model.modules())
diff --git a/tests/tests_pytorch/callbacks/test_memory_format.py b/tests/tests_pytorch/callbacks/test_memory_format.py
new file mode 100644
index 0000000000000..bf746ba153d05
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_memory_format.py
@@ -0,0 +1,76 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from unittest.mock import MagicMock
+
+import torch
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import MemoryFormat
+
+
+def test_memory_format_callback_setup():
+    class DummyModule(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, kernel_size=3)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    model = DummyModule()
+
+    # create a dummy Trainer
+    trainer = Trainer(max_epochs=1, devices=1)
+
+    # create the callback
+    callback = MemoryFormat()
+
+    # call the setup method
+    callback.setup(trainer, model)
+
+    # check that the model parameters are in channels_last memory format
+    assert model.conv.weight.is_contiguous(memory_format=torch.channels_last)
+
+
+def test_memory_format_callback():
+    # create a mock Trainer and a mock LightningModule
+    trainer = MagicMock()
+    pl_module = MagicMock()
+
+    # create a MemoryFormat callback
+    memory_format_callback = MemoryFormat()
+
+    # check that the callback sets the memory format correctly
+    memory_format_callback.setup(trainer=trainer, pl_module=pl_module)
+    assert pl_module.to.call_args[1]["memory_format"] == torch.channels_last
+
+    # check that the callback resets the memory format correctly
+    memory_format_callback.teardown(trainer=trainer, pl_module=pl_module)
+    assert pl_module.to.call_args[1]["memory_format"] == torch.contiguous_format
+
+    # check that the callback warns if the model doesn't have any layers benefiting from channels_last
+    pl_module.modules.return_value = [torch.nn.Linear(10, 10)]
+    with warnings.catch_warnings(record=True) as w:
+        # make sure the warning is not swallowed by a "once"-style filter from an earlier emission
+        warnings.simplefilter("always")
+        memory_format_callback.setup(trainer=trainer, pl_module=pl_module)
+    assert len(w) == 1
+    assert issubclass(w[-1].category, RuntimeWarning)
+    assert "model does not have any layers benefiting from" in str(w[-1].message)
+
+    # check that the callback converts input tensors to channels_last format
+    memory_format_callback.convert_input = True
+    batch = [torch.randn(16, 3, 32, 32), torch.randn(16, 3, 32, 32)]
+    memory_format_callback.on_train_batch_start(trainer=trainer, pl_module=pl_module, batch=batch, batch_idx=0)
+    for item in batch:
+        assert item.is_contiguous(memory_format=torch.channels_last)
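+
+
+# Illustrative sketch (not from the original PR): when `convert_input=True` but the batch is not a
+# MutableSequence (e.g. a dict), the callback is expected to warn instead of converting the batch.
+def test_memory_format_callback_non_sequence_batch():
+    trainer = MagicMock()
+    pl_module = MagicMock()
+    callback = MemoryFormat(convert_input=True)
+
+    batch = {"images": torch.randn(4, 3, 8, 8)}  # a dict is not a MutableSequence
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        callback.on_train_batch_start(trainer=trainer, pl_module=pl_module, batch=batch, batch_idx=0)
+
+    assert len(w) == 1
+    assert issubclass(w[-1].category, RuntimeWarning)
+    assert "cannot convert input" in str(w[-1].message)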