Skip to content

Commit 72f82eb

Browse files
awaelchliBorda
authored andcommitted
Avoid warning when cloning tensor in self.log (#14599)
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 10e09c6 commit 72f82eb

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010

1111
- Reset the dataloaders on OOM failure in batch size finder to use the last successful batch size ([#14372](https://github.com/Lightning-AI/lightning/pull/14372))
1212
- Fixed an issue to keep downscaling the batch size in case there hasn't been even a single successful optimal batch size with `mode="power"` ([#14372](https://github.com/Lightning-AI/lightning/pull/14372))
13+
- Fixed an issue where `self.log`-ing a tensor would create a user warning from PyTorch about cloning tensors ([#14599](https://github.com/Lightning-AI/lightning/pull/14599))
1314
- Fixed compatibility when `torch.distributed` is not available ([#14454](https://github.com/Lightning-AI/lightning/pull/14454))
1415
- Fixed torchscript error with ensembles of LightningModules ([#14657](https://github.com/Lightning-AI/lightning/pull/14657))
1516

src/pytorch_lightning/core/module.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,11 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
581581
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
582582

583583
def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor:
584-
value = torch.tensor(value, device=self.device)
584+
value = (
585+
value.clone().detach().to(self.device)
586+
if isinstance(value, torch.Tensor)
587+
else torch.tensor(value, device=self.device)
588+
)
585589
if not torch.numel(value) == 1:
586590
raise ValueError(
587591
f"`self.log({name}, {value})` was called, but the tensor must have a single element."

tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from pytorch_lightning.trainer.states import RunningStage
3333
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3434
from tests_pytorch.helpers.runif import RunIf
35+
from tests_pytorch.helpers.utils import no_warning_call
3536

3637

3738
def test__training_step__log(tmpdir):
@@ -626,6 +627,21 @@ def training_step(self, *args):
626627
trainer.fit(model)
627628

628629

630+
def test_log_tensor_and_clone_no_torch_warning(tmpdir):
631+
"""Regression test for issue https://github.com/Lightning-AI/lightning/issues/14594."""
632+
633+
class TestModel(BoringModel):
634+
def training_step(self, *args):
635+
self.log("foo", torch.tensor(1))
636+
return super().training_step(*args)
637+
638+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
639+
model = TestModel()
640+
match = r"recommended.*.clone\(\).detach\(\)"
641+
with no_warning_call(UserWarning, match=match):
642+
trainer.fit(model)
643+
644+
629645
def test_logging_raises(tmpdir):
630646
class TestModel(BoringModel):
631647
def training_step(self, batch, batch_idx):

0 commit comments

Comments
 (0)