Skip to content

Commit 231ee68

Browse files
committed
improve monitoring device
1 parent 0d37de8 commit 231ee68

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

ssak/utils/monitoring.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,17 @@ def __init__(self, output_folder="", name="", interval=0.25, device="cuda", plot
410410
if self.will_plot_monitoring:
411411
pass
412412
self.device = self.device if self.device else 0
413-
if self.device == "cuda" or self.device == "gpu":
413+
if self.device=="cuda" or self.device == "gpu":
414414
self.device = 0
415-
if self.device != "cpu":
416-
get_num_gpus()
415+
elif self.device.startswith("cuda:"):
416+
self.device = int(self.device.split(":")[1])
417+
if self.device != "cpu" and isinstance(self.device, int):
418+
num_gpus = get_num_gpus()
419+
if self.device>num_gpus:
420+
raise ValueError(f"GPU {self.device} doesn't exist, only {num_gpus} GPUs available")
417421
self.device = ALL_GPU_INDICES[self.device]
422+
elif self.device != "cpu":
423+
raise ValueError(f"Device {self.device} doesn't exist, use 'gpu', 'cpu', 'cuda', 'cuda:0' or '0' for example")
418424

419425
def _finish_step(self, monitoring, step_values, step=0, start=0):
420426
for i in step_values:

0 commit comments

Comments
 (0)