@@ -147,29 +147,52 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):
147147
148148
149149@pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
150- def test_tensorboard_log_graph (tmp_path , example_input_array ):
151- """Test that log graph works with both model.example_input_array and if array is passed externally."""
152- # TODO(fabric): Test both nn.Module and LightningModule
153- # TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
150+ def test_tensorboard_log_graph_plain_module (tmp_path , example_input_array ):
154151 model = BoringModel ()
155- if example_input_array is not None :
156- model .example_input_array = None
157-
158152 logger = TensorBoardLogger (tmp_path )
159153 logger ._experiment = Mock ()
154+
160155 logger .log_graph (model , example_input_array )
161156 if example_input_array is not None :
162157 logger .experiment .add_graph .assert_called_with (model , example_input_array )
158+ else :
159+ logger .experiment .add_graph .assert_not_called ()
160+
163161 logger ._experiment .reset_mock ()
164162
165- # model wrapped in `FabricModule`
166163 wrapped = _FabricModule (model , strategy = Mock ())
167164 logger .log_graph (wrapped , example_input_array )
168165 if example_input_array is not None :
169166 logger .experiment .add_graph .assert_called_with (model , example_input_array )
170167
171168
172- @pytest .mark .skipif (not _TENSORBOARD_AVAILABLE , reason = str (_TENSORBOARD_AVAILABLE ))
169+ @pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
170+ def test_tensorboard_log_graph_with_batch_transfer_hooks (tmp_path , example_input_array ):
171+ model = pytest .importorskip ("lightning.pytorch.demos.boring_classes" ).BoringModel ()
172+ logger = TensorBoardLogger (tmp_path )
173+ logger ._experiment = Mock ()
174+
175+ with (
176+ mock .patch .object (model , "_on_before_batch_transfer" , return_value = example_input_array ) as before_mock ,
177+ mock .patch .object (model , "_apply_batch_transfer_handler" , return_value = example_input_array ) as transfer_mock ,
178+ ):
179+ logger .log_graph (model , example_input_array )
180+ logger ._experiment .reset_mock ()
181+
182+ wrapped = _FabricModule (model , strategy = Mock ())
183+ logger .log_graph (wrapped , example_input_array )
184+
185+ if example_input_array is not None :
186+ assert before_mock .call_count == 2
187+ assert transfer_mock .call_count == 2
188+ logger .experiment .add_graph .assert_called_with (model , example_input_array )
189+ else :
190+ before_mock .assert_not_called ()
191+ transfer_mock .assert_not_called ()
192+ logger .experiment .add_graph .assert_not_called ()
193+
194+
195+ @pytest .mark .skipif (not _TENSORBOARD_AVAILABLE , reason = "tensorboard is required" )
173196def test_tensorboard_log_graph_warning_no_example_input_array (tmp_path ):
174197 """Test that log graph throws warning if model.example_input_array is None."""
175198 model = BoringModel ()
0 commit comments