
Commit daac956

Fix issue reported by pre-commit.
Signed-off-by: Wil Kong <[email protected]>
Parent: 4223f69


src/lightning/pytorch/core/module.py

Lines changed: 5 additions & 8 deletions
@@ -660,14 +660,11 @@ def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor
             # Keep tensor on its original device to avoid unnecessary transfers
             value = value.clone().detach()
         else:
-            if self.device.type == "cuda":
-                # Place scalar metrics on CPU to avoid CPU-GPU transfer and synchronization.
-                # `torch.tensor(value, device="cuda")` contains such synchronization, while the metric
-                # itself is only used on the CPU side. So placing metric on CPU for scalar inputs is more efficient.
-                device = "cpu"
-            else:
-                # For non-CUDA devices, maintain original behavior
-                device = self.device
+            # Place scalar metrics on CPU to avoid CPU-GPU transfer and synchronization.
+            # `torch.tensor(value, device="cuda")` contains such synchronization, while the metric
+            # itself is only used on the CPU side. So placing metric on CPU for scalar inputs is more efficient.
+            # For non-CUDA devices, maintain original behavior
+            device = "cpu" if self.device.type == "cuda" else self.device
             value = torch.tensor(value, device=device, dtype=_get_default_dtype())
 
         if not torch.numel(value) == 1:
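For context, here is a minimal standalone sketch of the patched device-selection logic. The function name to_tensor_sketch is hypothetical, and the dtype handling via the private _get_default_dtype helper is omitted; the real logic lives in LightningModule.__to_tensor.

import numbers
from typing import Union

import torch
from torch import Tensor


def to_tensor_sketch(value: Union[Tensor, numbers.Number], module_device: torch.device) -> Tensor:
    # Hypothetical free-function mirror of the patched branch above.
    if isinstance(value, Tensor):
        # Tensors stay on their original device to avoid unnecessary transfers.
        return value.clone().detach()
    # `torch.tensor(value, device="cuda")` implies a host-device synchronization,
    # and the logged scalar metric is consumed on the CPU side anyway, so scalar
    # inputs are materialized on CPU when the module lives on a CUDA device.
    device = "cpu" if module_device.type == "cuda" else module_device
    return torch.tensor(value, device=device)


print(to_tensor_sketch(0.5, torch.device("cuda")))  # tensor(0.5000), placed on CPU
print(to_tensor_sketch(0.5, torch.device("cpu")))   # tensor(0.5000), original behavior kept

The one-line conditional expression preserves behavior for non-CUDA devices while collapsing the previous if/else into a single assignment.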
