Skip to content

Commit 4c4ec25

Browse files
committed
split test_trainer_compiled_model
1 parent c18a2ba commit 4c4ec25

File tree

1 file changed

+54
-17
lines changed

1 file changed

+54
-17
lines changed

tests/tests_pytorch/utilities/test_compile.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
@pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found")
3535
@RunIf(dynamo=True, deepspeed=True)
3636
@mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt")
37-
def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
37+
def test_trainer_compiled_model_deepspeed(_, tmp_path, monkeypatch, mps_count_0):
3838
trainer_kwargs = {
3939
"default_root_dir": tmp_path,
4040
"fast_dev_run": True,
@@ -69,22 +69,59 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
6969
assert trainer.model._compiler_ctx is None
7070

7171
# some strategies do not support it
72-
if RequirementCache("deepspeed"):
73-
compiled_model = torch.compile(model)
74-
mock_cuda_count(monkeypatch, 2)
75-
76-
# TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import
77-
warn_context = (
78-
pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated")
79-
if _TORCH_GREATER_EQUAL_2_4
80-
else nullcontext()
81-
)
82-
83-
with warn_context:
84-
trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
85-
86-
with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
87-
trainer.fit(compiled_model)
72+
compiled_model = torch.compile(model)
73+
mock_cuda_count(monkeypatch, 2)
74+
75+
# TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import
76+
warn_context = (
77+
pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated")
78+
if _TORCH_GREATER_EQUAL_2_4
79+
else nullcontext()
80+
)
81+
82+
with warn_context:
83+
trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
84+
85+
with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
86+
trainer.fit(compiled_model)
87+
88+
# https://github.com/pytorch/pytorch/issues/95708
89+
@pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found")
90+
@RunIf(dynamo=True)
91+
@mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt")
92+
def test_trainer_compiled_model_ddp(_, tmp_path, monkeypatch, mps_count_0):
93+
trainer_kwargs = {
94+
"default_root_dir": tmp_path,
95+
"fast_dev_run": True,
96+
"logger": False,
97+
"enable_checkpointing": False,
98+
"enable_model_summary": False,
99+
"enable_progress_bar": False,
100+
}
101+
102+
model = BoringModel()
103+
compiled_model = torch.compile(model)
104+
assert model._compiler_ctx is compiled_model._compiler_ctx # shared reference
105+
106+
# can train with compiled model
107+
trainer = Trainer(**trainer_kwargs)
108+
trainer.fit(compiled_model)
109+
assert trainer.model._compiler_ctx["compiler"] == "dynamo"
110+
111+
# the compiled model can be uncompiled
112+
to_uncompiled_model = to_uncompiled(compiled_model)
113+
assert model._compiler_ctx is None
114+
assert compiled_model._compiler_ctx is None
115+
assert to_uncompiled_model._compiler_ctx is None
116+
117+
# the compiled model needs to be passed
118+
with pytest.raises(ValueError, match="required to be a compiled LightningModule"):
119+
to_uncompiled(to_uncompiled_model)
120+
121+
# the uncompiled model can be fitted
122+
trainer = Trainer(**trainer_kwargs)
123+
trainer.fit(model)
124+
assert trainer.model._compiler_ctx is None
88125

89126
# ddp does
90127
trainer = Trainer(strategy="ddp", **trainer_kwargs)

0 commit comments

Comments
 (0)