@@ -146,30 +146,47 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):
146146 logger .log_hyperparams (hparams , metrics )
147147
148148
149+ @pytest .mark .parametrize ("model_cls" , [BoringModel , pytest .importorskip ("lightning.pytorch.demos.boring_classes" ).BoringModel ])
149150@pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
150- def test_tensorboard_log_graph (tmp_path , example_input_array ):
151+ def test_tensorboard_log_graph (tmp_path , example_input_array , model_cls ):
151152 """Test that log graph works with both model.example_input_array and if array is passed externally."""
152- # TODO(fabric): Test both nn.Module and LightningModule
153- # TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
154- model = BoringModel ()
153+ model = model_cls ()
155154 if example_input_array is not None :
156155 model .example_input_array = None
157156
158157 logger = TensorBoardLogger (tmp_path )
159158 logger ._experiment = Mock ()
160- logger .log_graph (model , example_input_array )
161- if example_input_array is not None :
162- logger .experiment .add_graph .assert_called_with (model , example_input_array )
163- logger ._experiment .reset_mock ()
164159
165- # model wrapped in `FabricModule`
166- wrapped = _FabricModule (model , strategy = Mock ())
167- logger .log_graph (wrapped , example_input_array )
168- if example_input_array is not None :
169- logger .experiment .add_graph .assert_called_with (model , example_input_array )
160+ if isinstance (model , torch .nn .Module ) and hasattr (model , "_apply_batch_transfer_handler" ):
161+ with (
162+ mock .patch .object (model , "_on_before_batch_transfer" , return_value = example_input_array ) as before_mock ,
163+ mock .patch .object (model , "_apply_batch_transfer_handler" , return_value = example_input_array ) as transfer_mock ,
164+ ):
165+ logger .log_graph (model , example_input_array )
166+ logger ._experiment .reset_mock ()
167+ wrapped = _FabricModule (model , strategy = Mock ())
168+ logger .log_graph (wrapped , example_input_array )
169+ if example_input_array is not None :
170+ assert before_mock .call_count == 2
171+ assert transfer_mock .call_count == 2
172+ logger .experiment .add_graph .assert_called_with (model , example_input_array )
173+ else :
174+ before_mock .assert_not_called ()
175+ transfer_mock .assert_not_called ()
176+ logger .experiment .add_graph .assert_not_called ()
177+ else :
178+ logger .log_graph (model , example_input_array )
179+ if example_input_array is not None :
180+ logger .experiment .add_graph .assert_called_with (model , example_input_array )
181+ logger ._experiment .reset_mock ()
182+
183+ wrapped = _FabricModule (model , strategy = Mock ())
184+ logger .log_graph (wrapped , example_input_array )
185+ if example_input_array is not None :
186+ logger .experiment .add_graph .assert_called_with (model , example_input_array )
170187
171188
172- @pytest .mark .skipif (not _TENSORBOARD_AVAILABLE , reason = str ( _TENSORBOARD_AVAILABLE ) )
189+ @pytest .mark .skipif (not _TENSORBOARD_AVAILABLE , reason = "tensorboard is required" )
173190def test_tensorboard_log_graph_warning_no_example_input_array (tmp_path ):
174191 """Test that log graph throws warning if model.example_input_array is None."""
175192 model = BoringModel ()
0 commit comments