@@ -146,48 +146,50 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):
146146 logger .log_hyperparams (hparams , metrics )
147147
148148
149- @pytest .mark .parametrize (
150- "model_cls" , [BoringModel , pytest .importorskip ("lightning.pytorch.demos.boring_classes" ).BoringModel ]
151- )
152149@pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
153- def test_tensorboard_log_graph (tmp_path , example_input_array , model_cls ):
154- """Test that log graph works with both model.example_input_array and if array is passed externally."""
155- model = model_cls ()
150+ def test_tensorboard_log_graph_plain_module (tmp_path , example_input_array ):
151+ model = BoringModel ()
152+ logger = TensorBoardLogger (tmp_path )
153+ logger ._experiment = Mock ()
154+
155+ logger .log_graph (model , example_input_array )
156+ if example_input_array is not None :
157+ logger .experiment .add_graph .assert_called_with (model , example_input_array )
158+ else :
159+ logger .experiment .add_graph .assert_not_called ()
160+
161+ logger ._experiment .reset_mock ()
162+
163+ wrapped = _FabricModule (model , strategy = Mock ())
164+ logger .log_graph (wrapped , example_input_array )
156165 if example_input_array is not None :
157- model . example_input_array = None
166+ logger . experiment . add_graph . assert_called_with ( model , example_input_array )
158167
168+
169+ @pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
170+ def test_tensorboard_log_graph_with_batch_transfer_hooks (tmp_path , example_input_array ):
171+ model = pytest .importorskip ("lightning.pytorch.demos.boring_classes" ).BoringModel ()
159172 logger = TensorBoardLogger (tmp_path )
160173 logger ._experiment = Mock ()
161174
162- if isinstance (model , torch .nn .Module ) and hasattr (model , "_apply_batch_transfer_handler" ):
163- with (
164- mock .patch .object (model , "_on_before_batch_transfer" , return_value = example_input_array ) as before_mock ,
165- mock .patch .object (
166- model , "_apply_batch_transfer_handler" , return_value = example_input_array
167- ) as transfer_mock ,
168- ):
169- logger .log_graph (model , example_input_array )
170- logger ._experiment .reset_mock ()
171- wrapped = _FabricModule (model , strategy = Mock ())
172- logger .log_graph (wrapped , example_input_array )
173- if example_input_array is not None :
174- assert before_mock .call_count == 2
175- assert transfer_mock .call_count == 2
176- logger .experiment .add_graph .assert_called_with (model , example_input_array )
177- else :
178- before_mock .assert_not_called ()
179- transfer_mock .assert_not_called ()
180- logger .experiment .add_graph .assert_not_called ()
181- else :
175+ with (
176+ mock .patch .object (model , "_on_before_batch_transfer" , return_value = example_input_array ) as before_mock ,
177+ mock .patch .object (model , "_apply_batch_transfer_handler" , return_value = example_input_array ) as transfer_mock ,
178+ ):
182179 logger .log_graph (model , example_input_array )
183- if example_input_array is not None :
184- logger .experiment .add_graph .assert_called_with (model , example_input_array )
185180 logger ._experiment .reset_mock ()
186181
187182 wrapped = _FabricModule (model , strategy = Mock ())
188183 logger .log_graph (wrapped , example_input_array )
184+
189185 if example_input_array is not None :
186+ assert before_mock .call_count == 2
187+ assert transfer_mock .call_count == 2
190188 logger .experiment .add_graph .assert_called_with (model , example_input_array )
189+ else :
190+ before_mock .assert_not_called ()
191+ transfer_mock .assert_not_called ()
192+ logger .experiment .add_graph .assert_not_called ()
191193
192194
193195@pytest .mark .skipif (not _TENSORBOARD_AVAILABLE , reason = "tensorboard is required" )
0 commit comments