diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 208244dc38cd3..010e8b55636c3 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -211,8 +211,11 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) else: try: self.experiment.add_scalar(k, v, step) - # TODO(fabric): specify the possible exception - except Exception as ex: + except ( + NotImplementedError, + ValueError, + ModuleNotFoundError, # https://github.com/pytorch/pytorch/issues/24175 + ) as ex: raise ValueError( f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor." ) from ex diff --git a/tests/tests_pytorch/loggers/test_tensorboard.py b/tests/tests_pytorch/loggers/test_tensorboard.py index 7e02a73c93082..751224b93d3a6 100644 --- a/tests/tests_pytorch/loggers/test_tensorboard.py +++ b/tests/tests_pytorch/loggers/test_tensorboard.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import re from argparse import Namespace from unittest import mock from unittest.mock import Mock @@ -157,6 +158,18 @@ def test_tensorboard_log_metrics(tmp_path, step_idx): logger.log_metrics(metrics, step_idx) +@pytest.mark.parametrize("value", [[1], "x", None]) +def test_tensorboard_log_metrics_exception_message(tmp_path, value): + logger = TensorBoardLogger(tmp_path) + with pytest.raises( + ValueError, + match=re.escape( + f"you tried to log {value} which is currently not supported. Try a dict or a scalar/tensor.", + ), + ): + logger.log_metrics(metrics={"metric": value}) + + def test_tensorboard_log_hyperparams(tmp_path): logger = TensorBoardLogger(tmp_path) hparams = {