Skip to content

Commit a3c3c5d

Browse files
rohitgr7awaelchli
authored andcommitted
Squeeze tensor while logging (#14489)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 62cb1de commit a3c3c5d

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

src/pytorch_lightning/core/module.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,7 @@ def log(
419419
" but it should not contain information about `dataloader_idx`"
420420
)
421421

422-
value = apply_to_collection(value, numbers.Number, self.__to_tensor)
423-
apply_to_collection(value, torch.Tensor, self.__check_numel_1, name)
422+
value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name)
424423

425424
if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
426425
# if we started a new epoch (running its first batch) the hook name has changed
@@ -552,16 +551,15 @@ def __check_not_nested(value: dict, name: str) -> None:
552551
def __check_allowed(v: Any, name: str, value: Any) -> None:
553552
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
554553

555-
def __to_tensor(self, value: numbers.Number) -> Tensor:
556-
return torch.tensor(value, device=self.device)
557-
558-
@staticmethod
559-
def __check_numel_1(value: Tensor, name: str) -> None:
554+
def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor:
555+
value = torch.tensor(value, device=self.device)
560556
if not torch.numel(value) == 1:
561557
raise ValueError(
562558
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
563559
f" You can try doing `self.log({name}, {value}.mean())`"
564560
)
561+
value = value.squeeze()
562+
return value
565563

566564
def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
567565
"""Override this method to change the default behaviour of ``log_grad_norm``.

tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from pytorch_lightning import callbacks, Trainer
2929
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
3030
from pytorch_lightning.core.module import LightningModule
31-
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
31+
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset
32+
from pytorch_lightning.trainer.states import RunningStage
3233
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3334
from tests_pytorch.helpers.datasets import RandomDictDataset
3435
from tests_pytorch.helpers.runif import RunIf
@@ -837,3 +838,13 @@ def on_train_start(self):
837838

838839
assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)]
839840
assert trainer.max_epochs > 1
841+
842+
843+
def test_unsqueezed_tensor_logging():
844+
model = BoringModel()
845+
trainer = Trainer()
846+
trainer.state.stage = RunningStage.TRAINING
847+
model._current_fx_name = "training_step"
848+
model.trainer = trainer
849+
model.log("foo", torch.Tensor([1.2]))
850+
assert trainer.callback_metrics["foo"].ndim == 0

0 commit comments

Comments
 (0)