Skip to content

Commit 8e73a21

Browse files
authored
test_trainer_compiled_model change
Don't use compiler_ctx in case of OptimizedModule, unwrapping using fabric don't fill these fields
1 parent b17a3dc commit 8e73a21

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

tests/tests_pytorch/utilities/test_compile.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,14 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
4646

4747
model = BoringModel()
4848
compiled_model = torch.compile(model)
49-
assert model._compiler_ctx is compiled_model._compiler_ctx # shared reference
5049

5150
# can train with compiled model
5251
trainer = Trainer(**trainer_kwargs)
5352
trainer.fit(compiled_model)
54-
assert trainer.model._compiler_ctx["compiler"] == "dynamo"
53+
assert isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)
5554

5655
# the compiled model can be uncompiled
5756
to_uncompiled_model = to_uncompiled(compiled_model)
58-
assert model._compiler_ctx is None
59-
assert compiled_model._compiler_ctx is None
60-
assert to_uncompiled_model._compiler_ctx is None
6157

6258
# the compiled model needs to be passed
6359
with pytest.raises(ValueError, match="required to be a compiled LightningModule"):
@@ -66,7 +62,7 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
6662
# the uncompiled model can be fitted
6763
trainer = Trainer(**trainer_kwargs)
6864
trainer.fit(model)
69-
assert trainer.model._compiler_ctx is None
65+
assert not isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)
7066

7167
# some strategies do not support it
7268
if RequirementCache("deepspeed"):

0 commit comments

Comments
 (0)