@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import logging
+import sys
 from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -23,7 +25,7 @@
 from torch.utils.hooks import RemovableHandle

 import pytorch_lightning as pl
-from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation
+from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -282,12 +284,17 @@ def _forward_example_input(self) -> None:
         input_ = model.example_input_array
         input_ = model._apply_batch_transfer_handler(input_)

-        if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU:
-            model.forward = torch.cuda.amp.autocast()(model.forward)
-
         mode = model.training
         model.eval()
-        with torch.no_grad():
+
+        if trainer is not None:
+            forward_context = trainer.precision_plugin.forward_context()
+        elif sys.version_info >= (3, 7):
+            forward_context = contextlib.nullcontext()
+        else:
+            forward_context = contextlib.suppress()
+
+        with torch.no_grad(), forward_context:
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
                 model(*input_)
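To see what the rewritten dry run amounts to, here is a hedged, self-contained sketch. ToyPrecisionPlugin is a hypothetical stand-in for a Lightning precision plugin; a real plugin's forward_context() may return e.g. an autocast context rather than a no-op:

    import contextlib

    import torch


    class ToyPrecisionPlugin:
        # Hypothetical stand-in: Lightning's own plugins decide here how the
        # forward pass runs (e.g. under torch.cuda.amp.autocast()).
        def forward_context(self):
            return contextlib.nullcontext()


    model = torch.nn.Linear(4, 2)
    plugin = ToyPrecisionPlugin()

    # Mirrors the patched summary code: gradients stay disabled while the
    # precision plugin wraps the shape-collecting forward pass.
    with torch.no_grad(), plugin.forward_context():
        model(torch.randn(1, 4))

Delegating the choice of context to trainer.precision_plugin.forward_context() removes the hard-coded AMP/TPU check and the monkey-patching of model.forward, so each precision plugin controls its own dry-run behavior.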