Skip to content

Commit 518f759

Browse files
committed
test: enhance tensorboard log graph
1 parent 0773eb4 commit 518f759

File tree

1 file changed

+31
-14
lines changed

1 file changed

+31
-14
lines changed

tests/tests_fabric/loggers/test_tensorboard.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -146,30 +146,47 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):
146146
logger.log_hyperparams(hparams, metrics)
147147

148148

149+
@pytest.mark.parametrize("model_cls", [BoringModel, pytest.importorskip("lightning.pytorch.demos.boring_classes").BoringModel])
149150
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
150-
def test_tensorboard_log_graph(tmp_path, example_input_array):
151+
def test_tensorboard_log_graph(tmp_path, example_input_array, model_cls):
151152
"""Test that log graph works with both model.example_input_array and if array is passed externally."""
152-
# TODO(fabric): Test both nn.Module and LightningModule
153-
# TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
154-
model = BoringModel()
153+
model = model_cls()
155154
if example_input_array is not None:
156155
model.example_input_array = None
157156

158157
logger = TensorBoardLogger(tmp_path)
159158
logger._experiment = Mock()
160-
logger.log_graph(model, example_input_array)
161-
if example_input_array is not None:
162-
logger.experiment.add_graph.assert_called_with(model, example_input_array)
163-
logger._experiment.reset_mock()
164159

165-
# model wrapped in `FabricModule`
166-
wrapped = _FabricModule(model, strategy=Mock())
167-
logger.log_graph(wrapped, example_input_array)
168-
if example_input_array is not None:
169-
logger.experiment.add_graph.assert_called_with(model, example_input_array)
160+
if isinstance(model, torch.nn.Module) and hasattr(model, "_apply_batch_transfer_handler"):
161+
with (
162+
mock.patch.object(model, "_on_before_batch_transfer", return_value=example_input_array) as before_mock,
163+
mock.patch.object(model, "_apply_batch_transfer_handler", return_value=example_input_array) as transfer_mock,
164+
):
165+
logger.log_graph(model, example_input_array)
166+
logger._experiment.reset_mock()
167+
wrapped = _FabricModule(model, strategy=Mock())
168+
logger.log_graph(wrapped, example_input_array)
169+
if example_input_array is not None:
170+
assert before_mock.call_count == 2
171+
assert transfer_mock.call_count == 2
172+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
173+
else:
174+
before_mock.assert_not_called()
175+
transfer_mock.assert_not_called()
176+
logger.experiment.add_graph.assert_not_called()
177+
else:
178+
logger.log_graph(model, example_input_array)
179+
if example_input_array is not None:
180+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
181+
logger._experiment.reset_mock()
182+
183+
wrapped = _FabricModule(model, strategy=Mock())
184+
logger.log_graph(wrapped, example_input_array)
185+
if example_input_array is not None:
186+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
170187

171188

172-
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
189+
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason="tensorboard is required")
173190
def test_tensorboard_log_graph_warning_no_example_input_array(tmp_path):
174191
"""Test that log graph throws warning if model.example_input_array is None."""
175192
model = BoringModel()

0 commit comments

Comments
 (0)