Skip to content

Commit 8e42016

Browse files
committed
add deprecation to tests
1 parent 6b9db2c commit 8e42016

File tree

2 files changed

+43
-14
lines changed

2 files changed

+43
-14
lines changed

tests/tests_pytorch/helpers/test_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def test_models(tmp_path, data_class, model_class):
4646
if dm is not None:
4747
trainer.test(model, datamodule=dm)
4848

49-
model.to_torchscript()
49+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
50+
model.to_torchscript()
5051
if data_class:
5152
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)
5253

tests/tests_pytorch/models/test_torchscript.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from lightning.fabric.utilities.cloud_io import get_filesystem
2323
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
24+
from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning
2425
from lightning.pytorch.core.module import LightningModule
2526
from lightning.pytorch.demos.boring_classes import BoringModel
2627
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN
@@ -36,7 +37,8 @@ def test_torchscript_input_output(modelclass):
3637
if isinstance(model, BoringModel):
3738
model.example_input_array = torch.randn(5, 32)
3839

39-
script = model.to_torchscript()
40+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
41+
script = model.to_torchscript()
4042
assert isinstance(script, torch.jit.ScriptModule)
4143

4244
model.eval()
@@ -59,7 +61,8 @@ def test_torchscript_example_input_output_trace(modelclass):
5961
if isinstance(model, BoringModel):
6062
model.example_input_array = torch.randn(5, 32)
6163

62-
script = model.to_torchscript(method="trace")
64+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
65+
script = model.to_torchscript(method="trace")
6366
assert isinstance(script, torch.jit.ScriptModule)
6467

6568
model.eval()
@@ -74,7 +77,8 @@ def test_torchscript_input_output_trace():
7477
"""Test that traced LightningModule forward works with example_inputs."""
7578
model = BoringModel()
7679
example_inputs = torch.randn(1, 32)
77-
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
80+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
81+
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
7882
assert isinstance(script, torch.jit.ScriptModule)
7983

8084
model.eval()
@@ -99,7 +103,8 @@ def test_torchscript_device(device_str):
99103
model = BoringModel().to(device)
100104
model.example_input_array = torch.randn(5, 32)
101105

102-
script = model.to_torchscript()
106+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
107+
script = model.to_torchscript()
103108
assert next(script.parameters()).device == device
104109
script_output = script(model.example_input_array.to(device))
105110
assert script_output.device == device
@@ -121,19 +126,22 @@ def test_torchscript_device_with_check_inputs(device_str):
121126

122127
check_inputs = torch.rand(5, 32)
123128

124-
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
129+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
130+
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
125131
assert isinstance(script, torch.jit.ScriptModule)
126132

127133

128134
def test_torchscript_retain_training_state():
129135
"""Test that torchscript export does not alter the training mode of original model."""
130136
model = BoringModel()
131137
model.train(True)
132-
script = model.to_torchscript()
138+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
139+
script = model.to_torchscript()
133140
assert model.training
134141
assert not script.training
135142
model.train(False)
136-
_ = model.to_torchscript()
143+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
144+
_ = model.to_torchscript()
137145
assert not model.training
138146
assert not script.training
139147

@@ -142,7 +150,8 @@ def test_torchscript_retain_training_state():
142150
def test_torchscript_properties(modelclass):
143151
"""Test that scripted LightningModule has unnecessary methods removed."""
144152
model = modelclass()
145-
script = model.to_torchscript()
153+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
154+
script = model.to_torchscript()
146155
assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
147156
assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate")
148157
assert not callable(getattr(script, "training_step", None))
@@ -153,7 +162,8 @@ def test_torchscript_save_load(tmp_path, modelclass):
153162
"""Test that scripted LightningModule is correctly saved and can be loaded."""
154163
model = modelclass()
155164
output_file = str(tmp_path / "model.pt")
156-
script = model.to_torchscript(file_path=output_file)
165+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
166+
script = model.to_torchscript(file_path=output_file)
157167
loaded_script = torch.jit.load(output_file)
158168
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
159169

@@ -170,7 +180,8 @@ class DummyFileSystem(LocalFileSystem): ...
170180

171181
model = modelclass()
172182
output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmp_path, "model.pt")
173-
script = model.to_torchscript(file_path=output_file)
183+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
184+
script = model.to_torchscript(file_path=output_file)
174185

175186
fs = get_filesystem(output_file)
176187
with fs.open(output_file, "rb") as f:
@@ -184,7 +195,10 @@ def test_torchcript_invalid_method():
184195
model = BoringModel()
185196
model.train(True)
186197

187-
with pytest.raises(ValueError, match="only supports 'script' or 'trace'"):
198+
with (
199+
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
200+
pytest.raises(ValueError, match="only supports 'script' or 'trace'"),
201+
):
188202
model.to_torchscript(method="temp")
189203

190204

@@ -193,7 +207,10 @@ def test_torchscript_with_no_input():
193207
model = BoringModel()
194208
model.example_input_array = None
195209

196-
with pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"):
210+
with (
211+
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
212+
pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"),
213+
):
197214
model.to_torchscript(method="trace")
198215

199216

@@ -224,6 +241,17 @@ def forward(self, inputs):
224241

225242
lm = Parent()
226243
assert not lm._jit_is_scripting
227-
script = lm.to_torchscript(method="script")
244+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
245+
script = lm.to_torchscript(method="script")
228246
assert not lm._jit_is_scripting
229247
assert isinstance(script, torch.jit.RecursiveScriptModule)
248+
249+
250+
def test_to_torchscript_deprecation():
251+
"""Test that to_torchscript raises a deprecation warning."""
252+
model = BoringModel()
253+
model.example_input_array = torch.randn(5, 32)
254+
255+
with pytest.warns(LightningDeprecationWarning, match="has been deprecated in v2.7 and will be removed in v2.8"):
256+
script = model.to_torchscript()
257+
assert isinstance(script, torch.jit.ScriptModule)

0 commit comments

Comments
 (0)