@@ -324,19 +324,33 @@ def test_empty_model_size(max_depth):
324324 pytest .param ("mps" , marks = RunIf (mps = True )),
325325 ],
326326)
@pytest.mark.parametrize("precision", ["16-true", "32-true", "64-true"])
def test_model_size_precision(tmp_path, accelerator, precision):
    """Test model size for different precision types.

    The ``precision`` string's leading number (e.g. ``"16-true"`` -> 16) is the
    per-parameter bit width; it is handed to ``PreCalculatedModel`` so the model
    can pre-compute its expected size, while the full string is passed to the
    ``Trainer`` so training actually runs at that precision.
    """
    model = PreCalculatedModel(precision=int(precision.split("-")[0]))

    # fit model — one step is enough; we only need the summary afterwards
    trainer = Trainer(
        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=precision
    )
    trainer.fit(model)
    summary = summarize(model)
    assert model.pre_calculated_model_size == summary.model_size
339340
def test_model_size_warning_on_unsupported_precision():
    """Test that a warning is raised when the precision is not supported.

    ``"16-mixed"`` is a valid Lightning precision but the model summary cannot
    compute a size for it, so the summary is expected to warn and fall back to
    32-bit — which is why the pre-calculated reference size is built with 32.
    """
    model = PreCalculatedModel(precision=32)  # fallback to 32 bits

    # supported precision by lightning but not by the model summary
    trainer = Trainer(max_epochs=1, precision="16-mixed")
    trainer.fit(model)

    with pytest.warns(UserWarning, match="Precision 16-mixed is not supported by the model summary.*"):
        summary = summarize(model)
        assert model.pre_calculated_model_size == summary.model_size
353+
340354def test_lazy_model_summary ():
341355 """Test that the model summary can work with lazy layers."""
342356 lazy_model = LazyModel ()
0 commit comments