Skip to content

Commit c1746b2

Browse files
committed
feat: change to original device logic in setup.
1 parent 037a24b commit c1746b2

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

src/lightning/pytorch/trainer/setup.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,18 @@ def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str
142142

143143

144144
def _log_device_info(trainer: "pl.Trainer") -> None:
145-
if isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator)):
146-
gpu_used = trainer.num_devices
147-
device_names = list({trainer.accelerator.device_name(d) for d in trainer.devices})
145+
if CUDAAccelerator.is_available():
146+
if isinstance(trainer.accelerator, CUDAAccelerator):
147+
device_name = list({CUDAAccelerator.device_name(d) for d in trainer.device_ids})
148+
else:
149+
device_name = CUDAAccelerator.device_name()
150+
elif MPSAccelerator.is_available():
151+
device_name = MPSAccelerator.device_name()
148152
else:
149-
gpu_used = 0
150-
device_names = "False"
151-
rank_zero_info(f"GPU available: {device_names}, using: {gpu_used} {'devices' if gpu_used else 'device'}.")
153+
device_name = str(False)
154+
155+
gpu_used = trainer.num_devices if isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator)) else 0
156+
rank_zero_info(f"GPU available: {device_name}, using: {gpu_used} devices.")
152157

153158
num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
154159
rank_zero_info(f"TPU available: {XLAAccelerator.device_name()}, using: {num_tpu_cores} TPU cores")

0 commit comments

Comments
 (0)