Add MemoryFormat callback (channels last) #17680
Open
Pedrexus wants to merge 26 commits into Lightning-AI:master from Pedrexus:master
Commits (26)
fbe1bed channels last callback (Pedrexus)
2be15bd [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
2e5daba type hint fix (Pedrexus)
3365b5b changelog updated (Pedrexus)
eef597b test import error fix (Pedrexus)
896879a [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
a28c22f Merge branch 'master' into master (Borda)
97093d2 comments have been addressed (Pedrexus)
9f770ed [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
13ffbff Merge branch 'master' into master (Pedrexus)
7a3ab21 Merge branch 'master' into master (Borda)
ff0dfe0 Merge branch 'master' into master (Borda)
3d3e2e5 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
54def19 Merge remote-tracking branch 'upstream/master' (Pedrexus)
f837d47 import fixes (Pedrexus)
03a252f [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
10bd4f5 Merge branch 'master' into master (Pedrexus)
3f4c044 Merge branch 'master' into master (Borda)
dadb6c9 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
7607e0b Merge branch 'master' into master (Borda)
f6b3468 Merge branch 'master' into master (Borda)
35d3f17 Merge branch 'master' into master (Borda)
2a20155 Merge branch 'master' into master (Borda)
d524717 Apply suggestions from code review (Borda)
059d2af Merge branch 'master' into master (Borda)
84c0f23 Merge branch 'master' into master (Borda)
@@ -0,0 +1,80 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
MemoryFormat
============

Changes the model memory format.
"""

from collections.abc import MutableSequence
from typing import Any, Optional

import torch

import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_warn


class MemoryFormat(Callback):
    """The `MemoryFormat` callback changes the model memory format to `torch.channels_last` before training starts
    and restores `torch.contiguous_format` when it ends.

    See the `PyTorch memory format tutorial
    <https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_.

    Setting the memory format to channels_last usually improves GPU utilization.

    Runs on setup, so it can set the memory format before the model is wrapped by DDP.

    """

    # layer types for which channels_last is expected to pay off
    beneficial_layers = (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.Conv2d, torch.nn.Conv3d)

    def __init__(self, memory_format: torch.memory_format = torch.channels_last, convert_input: bool = False):
        self.memory_format = memory_format
        self.convert_input = convert_input

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        if self.memory_format in (
            torch.channels_last,
            torch.channels_last_3d,
        ) and not self.has_layer_benefiting_from_channels_last(pl_module):
            rank_zero_warn(
                f"model does not have any layers benefiting from {self.memory_format} format", category=RuntimeWarning
            )

        pl_module.to(memory_format=self.memory_format)

    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        pl_module.to(memory_format=torch.contiguous_format)

    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        if not self.convert_input:
            return

        if not isinstance(batch, MutableSequence):
            rank_zero_warn(
                f"batch is not a MutableSequence, cannot convert input to {self.memory_format}", category=RuntimeWarning
            )
            return

        for i, item in enumerate(batch):
            if isinstance(item, torch.Tensor):
                batch[i] = item.to(memory_format=self.memory_format)

    def has_layer_benefiting_from_channels_last(self, model: torch.nn.Module) -> bool:
        return any(isinstance(layer, self.beneficial_layers) for layer in model.modules())
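
For reviewers, a minimal usage sketch of the callback this diff adds. Illustrative only: `MemoryFormat` exists on this branch rather than in a released Lightning version, and the trainer arguments and model are arbitrary.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import MemoryFormat  # added by this PR

# convert model weights to channels_last for training; convert_input=True also
# converts each batch tensor in on_train_batch_start
trainer = Trainer(max_epochs=1, callbacks=[MemoryFormat(convert_input=True)])
# trainer.fit(model)  # weights converted in setup(), reset to contiguous in teardown()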
@@ -0,0 +1,76 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from unittest.mock import MagicMock

import torch

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import MemoryFormat


def test_memory_format_callback_setup():
    class DummyModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, kernel_size=3)

        def forward(self, x):
            return self.conv(x)

    model = DummyModule()

    # create a dummy Trainer
    trainer = Trainer(max_epochs=1, devices=1)

    # create the callback
    callback = MemoryFormat()

    # call the setup method
    callback.setup(trainer, model)

    # check that the memory format is channels_last
    assert model.conv.weight.is_contiguous(memory_format=torch.channels_last)


def test_memory_format_callback():
    # create a mock Trainer and LightningModule
    trainer = MagicMock()
    pl_module = MagicMock()

    # create a MemoryFormat callback
    memory_format_callback = MemoryFormat()

    # check that the callback sets the memory format correctly
    memory_format_callback.setup(trainer=trainer, pl_module=pl_module)
    assert pl_module.to.call_args[1]["memory_format"] == torch.channels_last

    # check that the callback resets the memory format correctly
    memory_format_callback.teardown(trainer=trainer, pl_module=pl_module)
    assert pl_module.to.call_args[1]["memory_format"] == torch.contiguous_format

    # check that the callback warns if the model doesn't have any layers benefiting from channels_last
    pl_module.modules.return_value = [torch.nn.Linear(10, 10)]
    with warnings.catch_warnings(record=True) as w:
        memory_format_callback.setup(trainer=trainer, pl_module=pl_module)
        assert len(w) == 1
        assert issubclass(w[-1].category, RuntimeWarning)
        assert "model does not have any layers benefiting from" in str(w[-1].message)

    # check that the callback converts input tensors to channels_last format
    memory_format_callback.convert_input = True
    batch = [torch.randn(16, 3, 32, 32), torch.randn(16, 3, 32, 32)]
    memory_format_callback.on_train_batch_start(trainer=trainer, pl_module=pl_module, batch=batch, batch_idx=0)
    for item in batch:
        assert item.is_contiguous(memory_format=torch.channels_last)
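
The second test leans on `unittest.mock.MagicMock` recording its calls; as a side note, here is a tiny standalone sketch of the `call_args` pattern it asserts against (standard-library behavior, nothing PR-specific):

import torch
from unittest.mock import MagicMock

mock_module = MagicMock()
mock_module.to(memory_format=torch.channels_last)

# call_args holds the most recent call; index [1] is its keyword-argument dict
assert mock_module.to.call_args[1]["memory_format"] == torch.channels_last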
Review comment:
@Pedrexus, just thinking if the name is right, since it rather does ReorderImageDimensions.
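
Background on the naming question above: `channels_last` changes only the physical strides, not the logical dimension order, so the tensor is still indexed as NCHW. A plain-PyTorch sketch of the distinction (illustrative shapes, not from the diff):

import torch

x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
print(x.shape)    # torch.Size([2, 3, 4, 5]) -- still logically NCHW
print(x.stride()) # (60, 1, 15, 3) -- memory is physically laid out as NHWC

# an actual dimension reorder would change the shape:
print(x.permute(0, 2, 3, 1).shape)  # torch.Size([2, 4, 5, 3])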