26 commits
- fbe1bed: channels last callback (Pedrexus, May 23, 2023)
- 2be15bd: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 23, 2023)
- 2e5daba: type hint fix (Pedrexus, May 23, 2023)
- 3365b5b: changelog updated (Pedrexus, May 23, 2023)
- eef597b: test import error fix (Pedrexus, May 23, 2023)
- 896879a: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 23, 2023)
- a28c22f: Merge branch 'master' into master (Borda, May 29, 2023)
- 97093d2: comments have been addressed (Pedrexus, May 31, 2023)
- 9f770ed: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 31, 2023)
- 13ffbff: Merge branch 'master' into master (Pedrexus, May 31, 2023)
- 7a3ab21: Merge branch 'master' into master (Borda, Aug 8, 2023)
- ff0dfe0: Merge branch 'master' into master (Borda, Oct 12, 2023)
- 3d3e2e5: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 12, 2023)
- 54def19: Merge remote-tracking branch 'upstream/master' (Pedrexus, Dec 13, 2023)
- f837d47: import fixes (Pedrexus, Dec 13, 2023)
- 03a252f: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 13, 2023)
- 10bd4f5: Merge branch 'master' into master (Pedrexus, Dec 30, 2023)
- 3f4c044: Merge branch 'master' into master (Borda, Mar 12, 2025)
- dadb6c9: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 12, 2025)
- 7607e0b: Merge branch 'master' into master (Borda, Mar 12, 2025)
- f6b3468: Merge branch 'master' into master (Borda, May 28, 2025)
- 35d3f17: Merge branch 'master' into master (Borda, Jun 11, 2025)
- 2a20155: Merge branch 'master' into master (Borda, Aug 8, 2025)
- d524717: Apply suggestions from code review (Borda, Aug 8, 2025)
- 059d2af: Merge branch 'master' into master (Borda, Aug 15, 2025)
- 84c0f23: Merge branch 'master' into master (Borda, Aug 15, 2025)
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `MemoryFormat` callback to easily change the model memory format ([#17680](https://github.com/Lightning-AI/lightning/pull/17680))


- Added `WeightAveraging` callback that wraps the PyTorch `AveragedModel` class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))


2 changes: 2 additions & 0 deletions src/lightning/pytorch/callbacks/__init__.py
@@ -21,6 +21,7 @@
from lightning.pytorch.callbacks.lambda_function import LambdaCallback
from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
from lightning.pytorch.callbacks.lr_monitor import LearningRateMonitor
from lightning.pytorch.callbacks.memory_format import MemoryFormat
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.callbacks.model_summary import ModelSummary
from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheckpoint
@@ -47,6 +48,7 @@
"LambdaCallback",
"LearningRateFinder",
"LearningRateMonitor",
"MemoryFormat",
"ModelCheckpoint",
"ModelPruning",
"ModelSummary",
80 changes: 80 additions & 0 deletions src/lightning/pytorch/callbacks/memory_format.py
@@ -0,0 +1,80 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
MemoryFormat
===============

changes the model memory format
"""

from collections.abc import MutableSequence
from typing import Any, Optional

import torch

import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_warn


class MemoryFormat(Callback):
Contributor review comment: @Pedrexus, just thinking if the name is right, since it rather does ReorderImageDimensions

"""The `MemoryFormat` callback changes the model memory format to `torch.channels_last` before training starts and
returns the original when it ends.

<https://\\pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_.

Setting the memory format channels_last usually improves GPU utilization.

Runs on setup, so it can set the memory format before the model is DDP wrapped.

"""

    def __init__(self, memory_format: torch.memory_format = torch.channels_last, convert_input: bool = False):
        self.memory_format = memory_format
        self.convert_input = convert_input

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        # Warn when the model has no Conv/BatchNorm (2d/3d) layers, since channels_last brings little benefit then.
        if self.memory_format in (
            torch.channels_last,
            torch.channels_last_3d,
        ) and not self.has_layer_benefiting_from_channels_last(pl_module):
            rank_zero_warn(
                f"model does not have any layers benefiting from {self.memory_format} format", category=RuntimeWarning
            )

        pl_module.to(memory_format=self.memory_format)

    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        # Restore the default contiguous memory format after the run ends.
        pl_module.to(memory_format=torch.contiguous_format)

    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        if not self.convert_input:
            return

        if not isinstance(batch, MutableSequence):
            rank_zero_warn(
                f"batch is not a MutableSequence, cannot convert input to {self.memory_format}", category=RuntimeWarning
            )
            return

        # Convert the tensors in the batch in place to the requested memory format.
        for i, item in enumerate(batch):
            if isinstance(item, torch.Tensor):
                batch[i] = item.to(memory_format=self.memory_format)

    beneficial_layers = (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.Conv2d, torch.nn.Conv3d)

    def has_layer_benefiting_from_channels_last(self, model: torch.nn.Module) -> bool:
        return any(isinstance(layer, self.beneficial_layers) for layer in model.modules())
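
For context, a minimal usage sketch follows (illustrative only, not part of this diff: the toy module, dataloader, and Trainer flags are assumptions). It shows the callback attached to a `Trainer` like any other Lightning callback, with `convert_input=True` so batch tensors are converted as well.

```python
# Illustrative usage sketch (not part of this PR): attach MemoryFormat to a Trainer.
import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import MemoryFormat


class TinyConvModel(LightningModule):
    """Toy model with a Conv2d layer, so channels_last is actually beneficial."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.conv(x).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(8, 3, 32, 32)), batch_size=4)
trainer = Trainer(
    max_epochs=1,
    logger=False,
    enable_checkpointing=False,
    # convert_input=True also converts the batch tensors on each training step
    callbacks=[MemoryFormat(torch.channels_last, convert_input=True)],
)
trainer.fit(TinyConvModel(), loader)
```

With a model containing `Conv2d`/`BatchNorm2d` layers, no warning is emitted and the parameters stay in `channels_last` for the duration of `fit`.
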
76 changes: 76 additions & 0 deletions tests/tests_pytorch/callbacks/test_memory_format.py
@@ -0,0 +1,76 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from unittest.mock import MagicMock

import torch

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import MemoryFormat


def test_memory_format_callback_setup():
    class DummyModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, kernel_size=3)

        def forward(self, x):
            return self.conv(x)

    model = DummyModule()

    # create a dummy Trainer
    trainer = Trainer(max_epochs=1, devices=1)

    # create the callback
    callback = MemoryFormat()

    # call the setup method
    callback.setup(trainer, model)

    # check that the memory format is channels_last
    assert model.conv.weight.is_contiguous(memory_format=torch.channels_last)


def test_memory_format_callback():
    # create a mock LightningModule
    trainer = MagicMock()
    pl_module = MagicMock()

    # create a MemoryFormat callback
    memory_format_callback = MemoryFormat()

    # check that the callback sets the memory format correctly
    memory_format_callback.setup(trainer=trainer, pl_module=pl_module)
    assert pl_module.to.call_args[1]["memory_format"] == torch.channels_last

    # check that the callback resets the memory format correctly
    memory_format_callback.teardown(trainer=trainer, pl_module=pl_module)
    assert pl_module.to.call_args[1]["memory_format"] == torch.contiguous_format

    # check that the callback warns if the model doesn't have any layers benefiting from channels_last
    pl_module.modules.return_value = [torch.nn.Linear(10, 10)]
    with warnings.catch_warnings(record=True) as w:
        memory_format_callback.setup(trainer=trainer, pl_module=pl_module)
        assert len(w) == 1
        assert issubclass(w[-1].category, RuntimeWarning)
        assert "model does not have any layers benefiting from" in str(w[-1].message)

    # check that the callback converts input tensors to channels_last format
    memory_format_callback.convert_input = True
    batch = [torch.randn(16, 3, 32, 32), torch.randn(16, 3, 32, 32)]
    memory_format_callback.on_train_batch_start(trainer=trainer, pl_module=pl_module, batch=batch, batch_idx=0)
    for item in batch:
        assert item.is_contiguous(memory_format=torch.channels_last)
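
The assertions above rely on `Tensor.is_contiguous(memory_format=...)`. As a short aside (not part of this PR), the snippet below illustrates what `channels_last` does to a 4D tensor's strides:

```python
# Side illustration (not part of this PR): effect of channels_last on a 4D tensor.
import torch

x = torch.randn(2, 3, 4, 4)                  # NCHW, default strides (48, 16, 4, 1)
y = x.to(memory_format=torch.channels_last)  # same shape, NHWC ordering in memory
assert y.is_contiguous(memory_format=torch.channels_last)
assert not y.is_contiguous()                 # not contiguous in the default (NCHW) sense
print(x.stride(), y.stride())                # (48, 16, 4, 1) vs (48, 1, 12, 3)
```
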