 @pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found")
 @RunIf(dynamo=True, deepspeed=True)
 @mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt")
-def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
+def test_trainer_compiled_model_deepspeed(_, tmp_path, monkeypatch, mps_count_0):
     trainer_kwargs = {
         "default_root_dir": tmp_path,
         "fast_dev_run": True,
@@ -69,22 +69,59 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
     assert trainer.model._compiler_ctx is None

     # some strategies do not support it
-    if RequirementCache("deepspeed"):
-        compiled_model = torch.compile(model)
-        mock_cuda_count(monkeypatch, 2)
-
-        # TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import
-        warn_context = (
-            pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated")
-            if _TORCH_GREATER_EQUAL_2_4
-            else nullcontext()
-        )
-
-        with warn_context:
-            trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
-
-        with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
-            trainer.fit(compiled_model)
+    compiled_model = torch.compile(model)
+    mock_cuda_count(monkeypatch, 2)
+
+    # TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import
+    warn_context = (
+        pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated")
+        if _TORCH_GREATER_EQUAL_2_4
+        else nullcontext()
+    )
+
+    with warn_context:
+        trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
+
+    with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
+        trainer.fit(compiled_model)
+
+# https://github.com/pytorch/pytorch/issues/95708
+@pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found")
+@RunIf(dynamo=True)
+@mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt")
+def test_trainer_compiled_model_ddp(_, tmp_path, monkeypatch, mps_count_0):
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "fast_dev_run": True,
+        "logger": False,
+        "enable_checkpointing": False,
+        "enable_model_summary": False,
+        "enable_progress_bar": False,
+    }
+
+    model = BoringModel()
+    compiled_model = torch.compile(model)
+    assert model._compiler_ctx is compiled_model._compiler_ctx  # shared reference
+
+    # can train with compiled model
+    trainer = Trainer(**trainer_kwargs)
+    trainer.fit(compiled_model)
+    assert trainer.model._compiler_ctx["compiler"] == "dynamo"
+
+    # the compiled model can be uncompiled
+    to_uncompiled_model = to_uncompiled(compiled_model)
+    assert model._compiler_ctx is None
+    assert compiled_model._compiler_ctx is None
+    assert to_uncompiled_model._compiler_ctx is None
+
+    # the compiled model needs to be passed
+    with pytest.raises(ValueError, match="required to be a compiled LightningModule"):
+        to_uncompiled(to_uncompiled_model)
+
+    # the uncompiled model can be fitted
+    trainer = Trainer(**trainer_kwargs)
+    trainer.fit(model)
+    assert trainer.model._compiler_ctx is None

     # ddp does
     trainer = Trainer(strategy="ddp", **trainer_kwargs)
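
For readers skimming the diff, here is a minimal sketch of the compile/uncompile round trip the new `test_trainer_compiled_model_ddp` test exercises. It assumes a recent Lightning 2.x where `from_compiled` and `to_uncompiled` are exposed from `lightning.pytorch.utilities.compile`; it is an illustration of the API, not an excerpt from the test file.

```python
import torch
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled

model = BoringModel()
compiled = torch.compile(model)  # wraps the module in torch._dynamo's OptimizedModule

# from_compiled() unwraps back to a LightningModule and records a `_compiler_ctx`
# dict on the original module, marking it as dynamo-compiled (the Trainer does
# the equivalent internally when handed an OptimizedModule, as in the test).
lit_model = from_compiled(compiled)
assert lit_model._compiler_ctx["compiler"] == "dynamo"

# to_uncompiled() restores the original methods and clears the context; calling
# it again on the now-plain module raises the ValueError asserted in the test.
plain = to_uncompiled(lit_model)
assert plain._compiler_ctx is None
```

Because the wrapper and the original module share the same `_compiler_ctx`, the test's `is`-based assertions hold on both objects after compiling and again after uncompiling.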
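The `warn_context` idiom kept in the DeepSpeed test is also a general pytest pattern worth noting: choose between `pytest.warns(...)` and `contextlib.nullcontext()` at runtime, so a single `with` block enforces a warning when it is expected and asserts nothing otherwise. A self-contained sketch, with the hypothetical `TORCH_GE_2_4` flag standing in for Lightning's `_TORCH_GREATER_EQUAL_2_4`:

```python
import warnings
from contextlib import nullcontext

import pytest

TORCH_GE_2_4 = True  # hypothetical stand-in for _TORCH_GREATER_EQUAL_2_4


def might_warn() -> None:
    # Emulates an import-time deprecation warning, like deepspeed's use of
    # `torch.cuda.amp.custom_fwd` on torch >= 2.4.
    if TORCH_GE_2_4:
        warnings.warn("torch.cuda.amp.custom_fwd is deprecated", FutureWarning)


def test_version_gated_warning() -> None:
    # pytest.warns() fails if no matching warning fires, so it is used only when
    # the warning is expected; nullcontext() is a no-op on the other branch.
    ctx = pytest.warns(FutureWarning, match="deprecated") if TORCH_GE_2_4 else nullcontext()
    with ctx:
        might_warn()
```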