@@ -1992,40 +1992,52 @@ def get_device_count():
19921992 return len (get_gpu_device_list ())
19931993
19941994
1995- def get_device_memory ():
1996- "get gpu memory"
1997- memory = 0
1995+ def get_device_memory_str ():
19981996 with tempfile .TemporaryDirectory () as temp_dirname :
19991997 suffix = ".exe" if is_windows () else ""
2000- # TODO: Use NRSU because we can't assume nvidia-smi across all platforms.
20011998 cmd = " " .join ([
2002- "nvidia-smi" + suffix , "--query-gpu=memory.total" ,
2003- "--format=csv,noheader"
2004- ])
2005- # Try to get memory from nvidia-smi first, if failed, fallback to system memory from /proc/meminfo
2006- # This fallback is needed for systems with unified memory (e.g. DGX Spark)
1999+ "nvidia-smi" + suffix , "--query-gpu=memory.total,memory.reserved,memory.used,memory.free" ,
2000+ "--format=csv,noheader"
2001+ ])
2002+ output = check_output (cmd , shell = True , cwd = temp_dirname )
2003+ return output .strip ()
2004+
2005+ def get_device_memory ():
2006+ "get gpu memory"
2007+ memory = 0
2008+ # Try to get memory from nvidia-smi first, if failed, fallback to system memory from /proc/meminfo
2009+ # This fallback is needed for systems with unified memory (e.g. DGX Spark)
2010+ try :
2011+ output = get_device_memory_str ()
2012+ memory_str = output .strip ().split ()[0 ]
2013+ # Check if nvidia-smi returned a valid numeric value
2014+ if "N/A" in memory_str :
2015+ raise ValueError ("nvidia-smi returned invalid memory info" )
2016+ memory = int (memory_str )
2017+ except (sp .CalledProcessError , ValueError , IndexError ):
2018+ # Fallback to system memory from /proc/meminfo (in kB, convert to MiB)
20072019 try :
2008- output = check_output (cmd , shell = True , cwd = temp_dirname )
2009- memory_str = output .strip ().split ()[0 ]
2010- # Check if nvidia-smi returned a valid numeric value
2011- if "N/A" in memory_str :
2012- raise ValueError ("nvidia-smi returned invalid memory info" )
2013- memory = int (memory_str )
2014- except (sp .CalledProcessError , ValueError , IndexError ):
2015- # Fallback to system memory from /proc/meminfo (in kB, convert to MiB)
2016- try :
2017- with open ("/proc/meminfo" , "r" ) as f :
2018- for line in f :
2019- if line .startswith ("MemTotal:" ):
2020- memory = int (
2021- line .split ()[1 ]) // 1024 # Convert kB to MiB
2022- break
2023- except :
2024- memory = 8192 # Default 8GB if all else fails
2020+ with open ("/proc/meminfo" , "r" ) as f :
2021+ for line in f :
2022+ if line .startswith ("MemTotal:" ):
2023+ memory = int (
2024+ line .split ()[1 ]) // 1024 # Convert kB to MiB
2025+ break
2026+ except :
2027+ memory = 8192 # Default 8GB if all else fails
20252028
20262029 return memory
20272030
20282031
2032+ def print_device_memory ():
2033+ memory_str = get_device_memory_str ()
2034+ print (f"Device Memory:\n total: reserved: used: free: \n { memory_str } " )
2035+ torch .cuda .empty_cache ()
2036+ import gc
2037+ gc .collect ()
2038+ memory_str = get_device_memory_str ()
2039+ print (f"Device Memory:\n total: reserved: used: free: \n { memory_str } " )
2040+
20292041def pytest_addoption (parser ):
20302042 parser .addoption (
20312043 "--test-list" ,
0 commit comments