2121
2222from lightning .fabric .utilities .cloud_io import get_filesystem
2323from lightning .fabric .utilities .imports import _IS_WINDOWS , _TORCH_GREATER_EQUAL_2_4
24+ from lightning .fabric .utilities .rank_zero import LightningDeprecationWarning
2425from lightning .pytorch .core .module import LightningModule
2526from lightning .pytorch .demos .boring_classes import BoringModel
2627from tests_pytorch .helpers .advanced_models import BasicGAN , ParityModuleRNN
@@ -36,7 +37,8 @@ def test_torchscript_input_output(modelclass):
3637 if isinstance (model , BoringModel ):
3738 model .example_input_array = torch .randn (5 , 32 )
3839
39- script = model .to_torchscript ()
40+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
41+ script = model .to_torchscript ()
4042 assert isinstance (script , torch .jit .ScriptModule )
4143
4244 model .eval ()
@@ -59,7 +61,8 @@ def test_torchscript_example_input_output_trace(modelclass):
5961 if isinstance (model , BoringModel ):
6062 model .example_input_array = torch .randn (5 , 32 )
6163
62- script = model .to_torchscript (method = "trace" )
64+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
65+ script = model .to_torchscript (method = "trace" )
6366 assert isinstance (script , torch .jit .ScriptModule )
6467
6568 model .eval ()
@@ -74,7 +77,8 @@ def test_torchscript_input_output_trace():
7477 """Test that traced LightningModule forward works with example_inputs."""
7578 model = BoringModel ()
7679 example_inputs = torch .randn (1 , 32 )
77- script = model .to_torchscript (example_inputs = example_inputs , method = "trace" )
80+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
81+ script = model .to_torchscript (example_inputs = example_inputs , method = "trace" )
7882 assert isinstance (script , torch .jit .ScriptModule )
7983
8084 model .eval ()
@@ -99,7 +103,8 @@ def test_torchscript_device(device_str):
99103 model = BoringModel ().to (device )
100104 model .example_input_array = torch .randn (5 , 32 )
101105
102- script = model .to_torchscript ()
106+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
107+ script = model .to_torchscript ()
103108 assert next (script .parameters ()).device == device
104109 script_output = script (model .example_input_array .to (device ))
105110 assert script_output .device == device
@@ -121,19 +126,22 @@ def test_torchscript_device_with_check_inputs(device_str):
121126
122127 check_inputs = torch .rand (5 , 32 )
123128
124- script = model .to_torchscript (method = "trace" , check_inputs = check_inputs )
129+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
130+ script = model .to_torchscript (method = "trace" , check_inputs = check_inputs )
125131 assert isinstance (script , torch .jit .ScriptModule )
126132
127133
128134def test_torchscript_retain_training_state ():
129135 """Test that torchscript export does not alter the training mode of original model."""
130136 model = BoringModel ()
131137 model .train (True )
132- script = model .to_torchscript ()
138+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
139+ script = model .to_torchscript ()
133140 assert model .training
134141 assert not script .training
135142 model .train (False )
136- _ = model .to_torchscript ()
143+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
144+ _ = model .to_torchscript ()
137145 assert not model .training
138146 assert not script .training
139147
@@ -142,7 +150,8 @@ def test_torchscript_retain_training_state():
142150def test_torchscript_properties (modelclass ):
143151 """Test that scripted LightningModule has unnecessary methods removed."""
144152 model = modelclass ()
145- script = model .to_torchscript ()
153+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
154+ script = model .to_torchscript ()
146155 assert not hasattr (model , "batch_size" ) or hasattr (script , "batch_size" )
147156 assert not hasattr (model , "learning_rate" ) or hasattr (script , "learning_rate" )
148157 assert not callable (getattr (script , "training_step" , None ))
@@ -153,7 +162,8 @@ def test_torchscript_save_load(tmp_path, modelclass):
153162 """Test that scripted LightningModule is correctly saved and can be loaded."""
154163 model = modelclass ()
155164 output_file = str (tmp_path / "model.pt" )
156- script = model .to_torchscript (file_path = output_file )
165+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
166+ script = model .to_torchscript (file_path = output_file )
157167 loaded_script = torch .jit .load (output_file )
158168 assert torch .allclose (next (script .parameters ()), next (loaded_script .parameters ()))
159169
@@ -170,7 +180,8 @@ class DummyFileSystem(LocalFileSystem): ...
170180
171181 model = modelclass ()
172182 output_file = os .path .join (_DUMMY_PRFEIX , _PREFIX_SEPARATOR , tmp_path , "model.pt" )
173- script = model .to_torchscript (file_path = output_file )
183+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
184+ script = model .to_torchscript (file_path = output_file )
174185
175186 fs = get_filesystem (output_file )
176187 with fs .open (output_file , "rb" ) as f :
@@ -184,7 +195,10 @@ def test_torchcript_invalid_method():
184195 model = BoringModel ()
185196 model .train (True )
186197
187- with pytest .raises (ValueError , match = "only supports 'script' or 'trace'" ):
198+ with (
199+ pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ),
200+ pytest .raises (ValueError , match = "only supports 'script' or 'trace'" ),
201+ ):
188202 model .to_torchscript (method = "temp" )
189203
190204
@@ -193,7 +207,10 @@ def test_torchscript_with_no_input():
193207 model = BoringModel ()
194208 model .example_input_array = None
195209
196- with pytest .raises (ValueError , match = "requires either `example_inputs` or `model.example_input_array`" ):
210+ with (
211+ pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ),
212+ pytest .raises (ValueError , match = "requires either `example_inputs` or `model.example_input_array`" ),
213+ ):
197214 model .to_torchscript (method = "trace" )
198215
199216
@@ -224,6 +241,17 @@ def forward(self, inputs):
224241
225242 lm = Parent ()
226243 assert not lm ._jit_is_scripting
227- script = lm .to_torchscript (method = "script" )
244+ with pytest .deprecated_call (match = "has been deprecated in v2.7 and will be removed in v2.8" ):
245+ script = lm .to_torchscript (method = "script" )
228246 assert not lm ._jit_is_scripting
229247 assert isinstance (script , torch .jit .RecursiveScriptModule )
248+
249+
250+ def test_to_torchscript_deprecation ():
251+ """Test that to_torchscript raises a deprecation warning."""
252+ model = BoringModel ()
253+ model .example_input_array = torch .randn (5 , 32 )
254+
255+ with pytest .warns (LightningDeprecationWarning , match = "has been deprecated in v2.7 and will be removed in v2.8" ):
256+ script = model .to_torchscript ()
257+ assert isinstance (script , torch .jit .ScriptModule )
0 commit comments