Skip to content

Commit 3851846

Browse files
author
dominicgkerr
committed
Log multi-value tensors as histograms
1 parent 93b5e9c commit 3851846

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

src/lightning/fabric/loggers/tensorboard.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,19 +205,22 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
205205
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
206206

207207
for k, v in metrics.items():
208-
if isinstance(v, Tensor):
208+
if isinstance(v, Tensor) and v.ndim == 0:
209209
v = v.item()
210210

211-
if isinstance(v, dict):
212-
self.experiment.add_scalars(k, v, step)
213-
else:
214-
try:
211+
try:
212+
if isinstance(v, dict):
213+
self.experiment.add_scalars(k, v, step)
214+
elif isinstance(v, Tensor):
215+
self.experiment.add_histogram(k, v, step)
216+
else:
215217
self.experiment.add_scalar(k, v, step)
216-
# TODO(fabric): specify the possible exception
217-
except Exception as ex:
218-
raise ValueError(
219-
f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor."
220-
) from ex
218+
219+
# TODO(fabric): specify the possible exception
220+
except Exception as ex:
221+
raise ValueError(
222+
f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor."
223+
) from ex
221224

222225
@override
223226
@rank_zero_only

tests/tests_pytorch/loggers/test_tensorboard.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,13 @@ def name(self):
154154
@pytest.mark.parametrize("step_idx", [10, None])
155155
def test_tensorboard_log_metrics(tmp_path, step_idx):
156156
logger = TensorBoardLogger(tmp_path)
157-
metrics = {"float": 0.3, "int": 1, "FloatTensor": torch.tensor(0.1), "IntTensor": torch.tensor(1)}
157+
metrics = {
158+
"float": 0.3,
159+
"int": 1,
160+
"FloatTensor": torch.tensor(0.1),
161+
"IntTensor": torch.tensor(1),
162+
"Histogram": torch.tensor([10, 100, 1000])
163+
}
158164
logger.log_metrics(metrics, step_idx)
159165

160166

0 commit comments

Comments
 (0)