@@ -398,6 +398,7 @@ class Monitoring:
398398
399399 def __init__ (self , output_folder = "" , name = "" , interval = 0.25 , device = "cuda" , plot_monitoring = True , show_steps_in_plots = True ):
400400 self .device = device
401+ self .device_name = None
401402 self .output_folder = output_folder
402403 if not name :
403404 self .name = output_folder
@@ -408,13 +409,25 @@ def __init__(self, output_folder="", name="", interval=0.25, device="cuda", plot
408409 self .will_plot_monitoring = plot_monitoring
409410 if self .will_plot_monitoring :
410411 pass
412+ self .device = self .device if self .device else 0
413+ if self .device == "cuda" or self .device == "gpu" :
414+ self .device = 0
415+ elif self .device .startswith ("cuda:" ):
416+ self .device = int (self .device .split (":" )[1 ])
417+ if self .device != "cpu" and isinstance (self .device , int ):
418+ num_gpus = get_num_gpus ()
419+ if self .device > num_gpus :
420+ raise ValueError (f"GPU { self .device } doesn't exist, only { num_gpus } GPUs available" )
421+ self .device = ALL_GPU_INDICES [self .device ]
422+ elif self .device != "cpu" :
423+ raise ValueError (f"Device { self .device } doesn't exist, use 'gpu', 'cpu', 'cuda', 'cuda:0' or '0' for example" )
411424
412425 def _finish_step (self , monitoring , step_values , step = 0 , start = 0 ):
413426 for i in step_values :
414427 if i not in monitoring :
415428 monitoring [i ] = []
416429 monitoring [i ].extend (step_values [i ])
417- if self .steps and len (self .steps )> 0 :
430+ if self .steps and len (self .steps ) > 0 and step < len ( self . steps ) :
418431 if "steps" not in monitoring :
419432 monitoring ["steps" ] = []
420433 if "steps_end" not in monitoring :
@@ -464,9 +477,11 @@ def _monitor(self):
464477 start = time .time () - monitoring ["time_points" ][- 1 ]
465478 if "device" in monitoring and monitoring ["device" ] != (pynvml .nvmlDeviceGetName (handle ) if handle else "cpu" ):
466479 raise ValueError ("The device used in the monitoring is different from the one specified in the current monitoring" )
480+ self .device_name = monitoring .get ("device" , "cpu" )
467481 else :
468482 monitoring = dict ()
469483 monitoring ["device" ] = pynvml .nvmlDeviceGetName (handle ) if handle else "cpu"
484+ self .device_name = monitoring ["device" ]
470485 start = time .time ()
471486 step = 0
472487 step_monitoring = dict ()
@@ -498,12 +513,6 @@ def start(self, steps=None):
498513 steps: list of str
499514 List of steps to monitor
500515 """
501- self .device = self .device if self .device else 0
502- if self .device == "cuda" or self .device == "gpu" :
503- self .device = 0
504- if self .device != "cpu" :
505- get_num_gpus ()
506- self .device = ALL_GPU_INDICES [self .device ]
507516 self .event_stop = threading .Event ()
508517 self .event_next = threading .Event ()
509518 self .event_error = threading .Event ()
@@ -530,6 +539,18 @@ def stop(self, error=False):
530539 self .event_stop .set ()
531540 self .monitoring_thread .join ()
532541
542+ def get_device_name (self ):
543+ if self .device_name is None :
544+ if self .device != "cpu" :
545+ pynvml .nvmlInit ()
546+ handle = pynvml .nvmlDeviceGetHandleByIndex (self .device )
547+ else :
548+ handle = None
549+ self .device_name = pynvml .nvmlDeviceGetName (handle ) if handle else "cpu"
550+ if handle :
551+ pynvml .nvmlShutdown ()
552+ return self .device_name
553+
533554 def plot_hardware (self , values , times , output_folder , ylabel = "RAM Usage" , lims = None , steps = None ):
534555 import matplotlib .pyplot as plt
535556
0 commit comments