@@ -46,18 +46,14 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
46
46
47
47
model = BoringModel ()
48
48
compiled_model = torch .compile (model )
49
- assert model ._compiler_ctx is compiled_model ._compiler_ctx # shared reference
50
49
51
50
# can train with compiled model
52
51
trainer = Trainer (** trainer_kwargs )
53
52
trainer .fit (compiled_model )
54
- assert trainer .model . _compiler_ctx [ "compiler" ] == "dynamo"
53
+ assert isinstance ( trainer .strategy . model , torch . _dynamo . OptimizedModule )
55
54
56
55
# the compiled model can be uncompiled
57
56
to_uncompiled_model = to_uncompiled (compiled_model )
58
- assert model ._compiler_ctx is None
59
- assert compiled_model ._compiler_ctx is None
60
- assert to_uncompiled_model ._compiler_ctx is None
61
57
62
58
# the compiled model needs to be passed
63
59
with pytest .raises (ValueError , match = "required to be a compiled LightningModule" ):
@@ -66,7 +62,7 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
66
62
# the uncompiled model can be fitted
67
63
trainer = Trainer (** trainer_kwargs )
68
64
trainer .fit (model )
69
- assert trainer .model . _compiler_ctx is None
65
+ assert not isinstance ( trainer .strategy . model , torch . _dynamo . OptimizedModule )
70
66
71
67
# some strategies do not support it
72
68
if RequirementCache ("deepspeed" ):
0 commit comments