diff --git a/modelopt/torch/utils/memory_monitor.py b/modelopt/torch/utils/memory_monitor.py index 2b7558537..94ed1cab8 100644 --- a/modelopt/torch/utils/memory_monitor.py +++ b/modelopt/torch/utils/memory_monitor.py @@ -131,7 +131,7 @@ def stop(self): nvmlShutdown() -def launch_memory_monitor(monitor_interval: float = 1.0) -> GPUMemoryMonitor: +def launch_memory_monitor(monitor_interval: float = 1.0) -> GPUMemoryMonitor | None: """Launch a GPU memory monitor in a separate thread. Args: @@ -140,6 +140,11 @@ def launch_memory_monitor(monitor_interval: float = 1.0) -> GPUMemoryMonitor: Returns: GPUMemoryMonitor: The monitor instance that was launched """ + try: + nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(0)) + except Exception as e: + print(f"Failed to get GPU memory info: {e}. Stopping GPU memory monitor.") + return None monitor = GPUMemoryMonitor(monitor_interval) monitor.start() atexit.register(monitor.stop) # Ensure the monitor stops when the program exits