diff --git a/tests/tests_fabric/loggers/test_tensorboard.py b/tests/tests_fabric/loggers/test_tensorboard.py index 4dcb86f0e7406..9cd61ef2e131b 100644 --- a/tests/tests_fabric/loggers/test_tensorboard.py +++ b/tests/tests_fabric/loggers/test_tensorboard.py @@ -147,29 +147,52 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path): @pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)]) -def test_tensorboard_log_graph(tmp_path, example_input_array): - """Test that log graph works with both model.example_input_array and if array is passed externally.""" - # TODO(fabric): Test both nn.Module and LightningModule - # TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks +def test_tensorboard_log_graph_plain_module(tmp_path, example_input_array): model = BoringModel() - if example_input_array is not None: - model.example_input_array = None - logger = TensorBoardLogger(tmp_path) logger._experiment = Mock() + logger.log_graph(model, example_input_array) if example_input_array is not None: logger.experiment.add_graph.assert_called_with(model, example_input_array) + else: + logger.experiment.add_graph.assert_not_called() + logger._experiment.reset_mock() - # model wrapped in `FabricModule` wrapped = _FabricModule(model, strategy=Mock()) logger.log_graph(wrapped, example_input_array) if example_input_array is not None: logger.experiment.add_graph.assert_called_with(model, example_input_array) -@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE)) +@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)]) +def test_tensorboard_log_graph_with_batch_transfer_hooks(tmp_path, example_input_array): + model = pytest.importorskip("lightning.pytorch.demos.boring_classes").BoringModel() + logger = TensorBoardLogger(tmp_path) + logger._experiment = Mock() + + with ( + mock.patch.object(model, "_on_before_batch_transfer", return_value=example_input_array) as before_mock, + mock.patch.object(model, "_apply_batch_transfer_handler", return_value=example_input_array) as transfer_mock, + ): + logger.log_graph(model, example_input_array) + logger._experiment.reset_mock() + + wrapped = _FabricModule(model, strategy=Mock()) + logger.log_graph(wrapped, example_input_array) + + if example_input_array is not None: + assert before_mock.call_count == 2 + assert transfer_mock.call_count == 2 + logger.experiment.add_graph.assert_called_with(model, example_input_array) + else: + before_mock.assert_not_called() + transfer_mock.assert_not_called() + logger.experiment.add_graph.assert_not_called() + + +@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason="tensorboard is required") def test_tensorboard_log_graph_warning_no_example_input_array(tmp_path): """Test that log graph throws warning if model.example_input_array is None.""" model = BoringModel()