Skip to content

Commit ecc6a9b

Browse files
committed
refactor: try catch a more specific exception during log_metrics.
1 parent e21b172 commit ecc6a9b

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

src/lightning/fabric/loggers/tensorboard.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
211211
else:
212212
try:
213213
self.experiment.add_scalar(k, v, step)
214-
# TODO(fabric): specify the possible exception
215-
except Exception as ex:
214+
except (NotImplementedError, ValueError) as ex:
216215
raise ValueError(
217216
f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor."
218217
) from ex

tests/tests_pytorch/loggers/test_tensorboard.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
import re
1516
from argparse import Namespace
1617
from unittest import mock
1718
from unittest.mock import Mock
@@ -157,6 +158,18 @@ def test_tensorboard_log_metrics(tmp_path, step_idx):
157158
logger.log_metrics(metrics, step_idx)
158159

159160

161+
@pytest.mark.parametrize("value", [[1], "x", None])
162+
def test_tensorboard_log_metrics_exception_message(tmp_path, value):
163+
logger = TensorBoardLogger(tmp_path)
164+
with pytest.raises(
165+
ValueError,
166+
match=re.escape(
167+
f"you tried to log {value} which is currently not supported. Try a dict or a scalar/tensor.",
168+
),
169+
):
170+
logger.log_metrics(metrics={"metric": value})
171+
172+
160173
def test_tensorboard_log_hyperparams(tmp_path):
161174
logger = TensorBoardLogger(tmp_path)
162175
hparams = {

0 commit comments

Comments
 (0)