 # limitations under the License.
 import inspect
 from contextlib import contextmanager
-from typing import Any, Callable, Generator, Optional, Tuple
+from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type

 import torch
 import torch.distributed as dist
 from torch import Tensor

 import lightning.pytorch as pl
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch.accelerators import TPUAccelerator
 from lightning.pytorch.callbacks.timer import Timer
 from lightning.pytorch.loops import _Loop
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher
 from lightning.pytorch.loops.progress import _BaseProgress
+from lightning.pytorch.strategies import FSDPStrategy
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.strategy import Strategy
 from lightning.pytorch.trainer.states import RunningStage
@@ -153,16 +155,21 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any:
             raise TypeError(f"`{type(self).__name__}` needs to be a Loop.")
         if not hasattr(self, "inference_mode"):
             raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined")
-        context_manager = (
-            torch.inference_mode
-            if (
-                self.inference_mode
-                # inference mode is not supported with gloo backend (#9431) and TPU accelerators.
-                and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
-                and not isinstance(self.trainer.accelerator, TPUAccelerator)
-            )
-            else torch.no_grad
-        )
+        context_manager: Type[ContextManager]
+        if dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo":
+            # gloo backend does not work properly.
+            # https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110
+            # TODO: explore why and possibly open an issue in PyTorch repository
+            context_manager = torch.no_grad
+        elif isinstance(self.trainer.accelerator, TPUAccelerator):
+            context_manager = torch.no_grad
+        elif _TORCH_GREATER_EQUAL_1_13 and isinstance(self.trainer.strategy, FSDPStrategy):
+            # https://github.com/pytorch/pytorch/issues/95957
+            context_manager = torch.no_grad
+        elif self.inference_mode:
+            context_manager = torch.inference_mode
+        else:
+            context_manager = torch.no_grad
         with context_manager():
             return loop_run(self, *args, **kwargs)

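Outside of the Lightning internals, the decision in this diff boils down to: fall back to `torch.no_grad` whenever the distributed backend, accelerator, or strategy is known to misbehave under `torch.inference_mode`, and only then honor the loop's `inference_mode` flag. A minimal standalone sketch of that selection logic (the helper name and boolean flags are hypothetical stand-ins for the trainer state the real `_decorator` inspects, not part of Lightning or PyTorch):

```python
from typing import ContextManager, Type

import torch


def choose_eval_context(
    use_inference_mode: bool,
    gloo_backend: bool = False,
    tpu_accelerator: bool = False,
    fsdp_strategy: bool = False,
) -> Type[ContextManager]:
    """Pick the no-grad context for an evaluation loop.

    Any configuration with a known `torch.inference_mode` incompatibility
    falls back to `torch.no_grad`; otherwise the user's preference is honored.
    """
    if gloo_backend:
        # gloo + inference_mode misbehaves (see the PR discussion linked in the diff).
        return torch.no_grad
    if tpu_accelerator:
        return torch.no_grad
    if fsdp_strategy:
        # FSDP + inference_mode is problematic (pytorch/pytorch#95957).
        return torch.no_grad
    if use_inference_mode:
        return torch.inference_mode
    return torch.no_grad


# Example: inference mode is chosen only when nothing rules it out.
ctx = choose_eval_context(use_inference_mode=True)
with ctx():
    preds = torch.ones(4) * 2  # created without autograd tracking
```

Writing the fallbacks as an explicit `if`/`elif` chain, as the diff does, lets each known incompatibility carry its own comment and issue link instead of being folded into one conditional expression.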
|
|