Skip to content

Commit a9047fe

Browse files
committed
fix: fix default device logics.
1 parent e6f8097 commit a9047fe

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/lightning/pytorch/core/module.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,8 +1548,9 @@ def forward(self, x):
15481548
device = self.device
15491549
if self.device.type != "cuda":
15501550
default_device = torch.device(default_device) if isinstance(default_device, str) else default_device
1551-
if default_device.type != "cuda":
1552-
raise ValueError(
1551+
1552+
if not torch.cuda.is_available() or default_device.type != "cuda":
1553+
raise MisconfigurationException(
15531554
f"TensorRT only supports CUDA devices. The current device is {self.device}."
15541555
f" Please set the `default_device` argument to a CUDA device."
15551556
)

0 commit comments

Comments
 (0)