Commit 1f32923

rohitgr7, justusschock, and awaelchli authored and committed
Avoid torch amp cuda warning with bf16 on cpu (#11161)
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent ff9c3f9 commit 1f32923
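
For context, a minimal sketch of the warning this commit avoids. It assumes a machine or run without usable CUDA (e.g. `Trainer(precision="bf16")` on CPU); the exact warning text depends on the torch version, and `forward` here is a stand-in, not Lightning code.

```python
import torch

def forward(x):
    # stand-in for a LightningModule's forward
    return x * 2

# Pre-fix behavior: the model summary wrapped forward in the CUDA autocast
# context whenever native AMP was enabled, regardless of device. On a
# CUDA-less run, merely constructing this context makes torch warn,
# roughly: "User provided device_type of 'cuda', but CUDA is not
# available. Disabling".
forward = torch.cuda.amp.autocast()(forward)
forward(torch.ones(2))

# The matching context for bf16 on CPU (torch >= 1.10) is instead:
with torch.autocast("cpu", dtype=torch.bfloat16):
    torch.ones(2) * 2
```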

File tree

2 files changed: +14 -6 lines changed


CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Avoid the deprecated `onnx.export(example_outputs=...)` in torch 1.10 ([#11116](https://github.com/PyTorchLightning/pytorch-lightning/pull/11116))
 - Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))
 - Fixed an `AttributeError` occuring when using a `CombinedLoader` (multiple dataloaders) for prediction ([#11111](https://github.com/PyTorchLightning/pytorch-lightning/pull/11111))
-- Fixed bug where `Trainer(track_grad_norm=..., logger=False)' would fail ([#11114](https://github.com/PyTorchLightning/pytorch-lightning/pull/11114))
+- Fixed bug where `Trainer(track_grad_norm=..., logger=False)` would fail ([#11114](https://github.com/PyTorchLightning/pytorch-lightning/pull/11114))
+- Fixed an incorrect warning being produced by the model summary when using `bf16` precision on CPU ([#11161](https://github.com/PyTorchLightning/pytorch-lightning/pull/11161))

 ### Changed

pytorch_lightning/utilities/model_summary.py

Lines changed: 12 additions & 5 deletions
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import logging
+import sys
 from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -23,7 +25,7 @@
 from torch.utils.hooks import RemovableHandle

 import pytorch_lightning as pl
-from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation
+from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.warnings import WarningCache
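
The two new imports above exist for a small compatibility shim, sketched below before the hunk that uses it; `no_op` is an illustrative name, not from the patch. `contextlib.nullcontext` was only added in Python 3.7, and an argument-less `contextlib.suppress()` suppresses no exceptions, so it doubles as a no-op context manager on Python 3.6.

```python
import contextlib
import sys

# nullcontext (3.7+) and an empty suppress() behave identically here:
# both enter and exit without doing anything.
if sys.version_info >= (3, 7):
    no_op = contextlib.nullcontext()
else:
    no_op = contextlib.suppress()

with no_op:
    print("runs as if there were no context manager at all")
```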
@@ -282,12 +284,17 @@ def _forward_example_input(self) -> None:
         input_ = model.example_input_array
         input_ = model._apply_batch_transfer_handler(input_)

-        if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU:
-            model.forward = torch.cuda.amp.autocast()(model.forward)
-
         mode = model.training
         model.eval()
-        with torch.no_grad():
+
+        if trainer is not None:
+            forward_context = trainer.precision_plugin.forward_context()
+        elif sys.version_info >= (3, 7):
+            forward_context = contextlib.nullcontext()
+        else:
+            forward_context = contextlib.suppress()
+
+        with torch.no_grad(), forward_context:
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
                 model(*input_)
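
The fix defers to the trainer's precision plugin, which already knows which device and dtype it manages. A rough sketch of the idea, assuming simplified plugin bodies (the class names mirror Lightning's, but this is not the library source):

```python
import contextlib
from typing import Generator

import torch

class PrecisionPlugin:
    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        yield  # full precision: no autocast needed

class NativeMixedPrecisionPlugin(PrecisionPlugin):
    def __init__(self, device_type: str, dtype: torch.dtype) -> None:
        self.device_type = device_type
        self.dtype = dtype

    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        # Autocast is entered for the device actually in use, so a CPU
        # bf16 run never touches torch.cuda.amp and never warns.
        with torch.autocast(self.device_type, dtype=self.dtype):
            yield

# e.g. bf16 on CPU:
plugin = NativeMixedPrecisionPlugin("cpu", torch.bfloat16)
with torch.no_grad(), plugin.forward_context():
    torch.ones(2) * 2
```

With that in place, the summary no longer needs to special-case AMP backends or device types; it simply stacks `torch.no_grad()` with whatever context the plugin yields.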
