Skip to content

Commit 637ac4b

Browse files
carmocca authored and lantiga committed
Avoid inference_mode with torch.compile (#17215)
1 parent a46b2b1 commit 637ac4b

File tree

5 files changed

+27
-2
lines changed

5 files changed

+27
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010

1111
- Pickling the `LightningModule` no longer pickles the `Trainer` ([#17133](https://github.com/Lightning-AI/lightning/pull/17133))
1212
- Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))
13+
- Disable `torch.inference_mode` with `torch.compile` in PyTorch 2.0 ([#17215](https://github.com/Lightning-AI/lightning/pull/17215))
1314

1415
### Fixed
1516

src/lightning/pytorch/loops/utilities.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from torch import Tensor
2121

2222
import lightning.pytorch as pl
23-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
23+
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_1_13
2424
from lightning.fabric.utilities.warnings import PossibleUserWarning
2525
from lightning.pytorch.accelerators import TPUAccelerator
2626
from lightning.pytorch.callbacks.timer import Timer
@@ -166,6 +166,9 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any:
166166
elif _TORCH_GREATER_EQUAL_1_13 and isinstance(self.trainer.strategy, FSDPStrategy):
167167
# https://github.com/pytorch/pytorch/issues/95957
168168
context_manager = torch.no_grad
169+
elif _TORCH_EQUAL_2_0 and self.trainer.lightning_module._compiler_ctx is not None:
170+
# avoid: `RuntimeError: Inference tensors do not track version counter` fixed in v2.1
171+
context_manager = torch.no_grad
169172
elif self.inference_mode:
170173
context_manager = torch.inference_mode
171174
else:

tests/tests_pytorch/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,7 @@ def restore_env_variables():
7575
"KMP_INIT_AT_FORK", # leaked since PyTorch 1.13
7676
"KMP_DUPLICATE_LIB_OK", # leaked since PyTorch 1.13
7777
"CRC32C_SW_MODE", # leaked by tensorboardX
78+
"TRITON_CACHE_DIR", # leaked by torch.compile
7879
# leaked by XLA
7980
"ALLOW_MULTIPLE_LIBTPU_LOAD",
8081
"GRPC_VERBOSITY",

tests/tests_pytorch/trainer/flags/test_inference_mode.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytest
1818
import torch
1919

20+
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
2021
from lightning.pytorch import Trainer
2122
from lightning.pytorch.demos.boring_classes import BoringModel
2223
from lightning.pytorch.loops import _Loop
@@ -86,4 +87,5 @@ def run(self):
8687
f.inference_mode = True
8788
with mock.patch("torch.inference_mode") as inference_mode_mock:
8889
f.run()
89-
inference_mode_mock.assert_called_once_with()
90+
if not _TORCH_EQUAL_2_0:
91+
inference_mode_mock.assert_called_once_with()

tests/tests_pytorch/utilities/test_compile.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -148,3 +148,21 @@ def training_step(self, batch, batch_idx):
148148
trainer.fit(compiled_model)
149149

150150
assert set(trainer.callback_metrics) == {"loss"}
151+
152+
153+
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
154+
@RunIf(min_torch="2.0.0")
155+
def test_trainer_compiled_model_test(tmp_path):
156+
skip_if_unsupported()
157+
158+
model = BoringModel()
159+
compiled_model = torch.compile(model)
160+
161+
trainer = Trainer(
162+
default_root_dir=tmp_path,
163+
fast_dev_run=True,
164+
enable_checkpointing=False,
165+
enable_model_summary=False,
166+
enable_progress_bar=False,
167+
)
168+
trainer.test(compiled_model)

0 commit comments

Comments
 (0)