Commit 67b94ef

Avoid inference_mode with FSDP (#17064)
1 parent 8434ee7 commit 67b94ef

File tree: 2 files changed, +18 −12 lines changed

    src/lightning/pytorch/loops/utilities.py
    tests/tests_pytorch/strategies/test_fsdp.py

src/lightning/pytorch/loops/utilities.py

Lines changed: 18 additions & 11 deletions
@@ -13,19 +13,21 @@
 # limitations under the License.
 import inspect
 from contextlib import contextmanager
-from typing import Any, Callable, Generator, Optional, Tuple
+from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type
 
 import torch
 import torch.distributed as dist
 from torch import Tensor
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch.accelerators import TPUAccelerator
 from lightning.pytorch.callbacks.timer import Timer
 from lightning.pytorch.loops import _Loop
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher
 from lightning.pytorch.loops.progress import _BaseProgress
+from lightning.pytorch.strategies import FSDPStrategy
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.strategy import Strategy
 from lightning.pytorch.trainer.states import RunningStage
@@ -153,16 +155,21 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any:
             raise TypeError(f"`{type(self).__name__}` needs to be a Loop.")
         if not hasattr(self, "inference_mode"):
             raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined")
-        context_manager = (
-            torch.inference_mode
-            if (
-                self.inference_mode
-                # inference mode is not supported with gloo backend (#9431) and TPU accelerators.
-                and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
-                and not isinstance(self.trainer.accelerator, TPUAccelerator)
-            )
-            else torch.no_grad
-        )
+        context_manager: Type[ContextManager]
+        if dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo":
+            # gloo backend does not work properly.
+            # https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110
+            # TODO: explore why and possibly open an issue in PyTorch repository
+            context_manager = torch.no_grad
+        elif isinstance(self.trainer.accelerator, TPUAccelerator):
+            context_manager = torch.no_grad
+        elif _TORCH_GREATER_EQUAL_1_13 and isinstance(self.trainer.strategy, FSDPStrategy):
+            # https://github.com/pytorch/pytorch/issues/95957
+            context_manager = torch.no_grad
+        elif self.inference_mode:
+            context_manager = torch.inference_mode
+        else:
+            context_manager = torch.no_grad
         with context_manager():
             return loop_run(self, *args, **kwargs)
 
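For context (not part of the diff): torch.no_grad is the safer fallback because tensors produced under torch.inference_mode are permanently barred from autograd, which is what trips up FSDP's evaluation path (pytorch/pytorch#95957), whereas the outputs of a torch.no_grad region are ordinary tensors that can still feed later gradient computations. A minimal standalone sketch of that difference, assuming only a recent PyTorch install:

import torch

# Created inside inference_mode: an "inference tensor" that can never
# participate in autograd again.
with torch.inference_mode():
    x = torch.ones(3)

# Created inside no_grad: gradient tracking is skipped here, but the result
# is a normal tensor and may be reused in later autograd graphs.
with torch.no_grad():
    y = torch.ones(3)

w = torch.ones(3, requires_grad=True)

(y * w).sum().backward()   # fine, populates w.grad
print(w.grad)

try:
    (x * w).sum().backward()
except RuntimeError as err:
    # PyTorch rejects inference tensors in autograd, e.g.
    # "Inference tensors cannot be saved for backward ..."
    print(f"inference_mode tensor rejected: {err}")

Because FSDP's forward pass currently fails under inference_mode (see the linked PyTorch issue), the loop now falls back to no_grad whenever an FSDPStrategy is active on torch >= 1.13.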

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 0 additions & 1 deletion
@@ -247,7 +247,6 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy):
         limit_test_batches=2,
         limit_predict_batches=2,
         callbacks=[ck],
-        inference_mode=not _TORCH_GREATER_EQUAL_2_0,  # TODO(carmocca): inference_mode raises RuntimeError
     )
     _run_multiple_stages(trainer, model)
 
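With the loop-level fallback in place, the test no longer needs to opt out of inference mode itself. As a rough illustration (not code from this commit; it assumes 2 CUDA devices and uses Lightning's BoringModel demo module), an FSDP run can now keep the Trainer default inference_mode=True:

import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel

# inference_mode stays at its default (True); during evaluation the loop
# substitutes torch.no_grad because an FSDPStrategy is active.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="fsdp",
    fast_dev_run=True,  # a single batch per stage is enough to exercise the eval loops
)
model = BoringModel()
trainer.fit(model)
trainer.test(model)  # previously required inference_mode=False to avoid a RuntimeError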
