@@ -146,7 +146,9 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):
146146 logger .log_hyperparams (hparams , metrics )
147147
148148
149- @pytest .mark .parametrize ("model_cls" , [BoringModel , pytest .importorskip ("lightning.pytorch.demos.boring_classes" ).BoringModel ])
149+ @pytest .mark .parametrize (
150+ "model_cls" , [BoringModel , pytest .importorskip ("lightning.pytorch.demos.boring_classes" ).BoringModel ]
151+ )
150152@pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
151153def test_tensorboard_log_graph (tmp_path , example_input_array , model_cls ):
152154 """Test that log graph works with both model.example_input_array and if array is passed externally."""
@@ -160,7 +162,9 @@ def test_tensorboard_log_graph(tmp_path, example_input_array, model_cls):
160162 if isinstance (model , torch .nn .Module ) and hasattr (model , "_apply_batch_transfer_handler" ):
161163 with (
162164 mock .patch .object (model , "_on_before_batch_transfer" , return_value = example_input_array ) as before_mock ,
163- mock .patch .object (model , "_apply_batch_transfer_handler" , return_value = example_input_array ) as transfer_mock ,
165+ mock .patch .object (
166+ model , "_apply_batch_transfer_handler" , return_value = example_input_array
167+ ) as transfer_mock ,
164168 ):
165169 logger .log_graph (model , example_input_array )
166170 logger ._experiment .reset_mock ()
0 commit comments