Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions nvitop/api/libnvml.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,20 @@

# Python Bindings for the NVIDIA Management Library (NVML)
# https://pypi.org/project/nvidia-ml-py
import pynvml as _pynvml
from pynvml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import
from pynvml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import

from nvitop.api.utils import NA, UINT_MAX, ULONGLONG_MAX, NaType
from nvitop.api.utils import NA, UINT_MAX, ULONGLONG_MAX, NaType, is_musa
from nvitop.api.utils import colored as __colored

_is_musa = is_musa()

if not _is_musa:
import pynvml as _pynvml
from pynvml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import
from pynvml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import
else:
import pymtml as _pynvml
from pymtml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import
from pymtml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import

if _TYPE_CHECKING:
from collections.abc import Callable as _Callable
Expand Down Expand Up @@ -540,7 +547,10 @@ def nvmlCheckReturn(retval: _Any, types: type | tuple[type, ...] | None = None,
# Patch function `nvmlDeviceGet{Compute,Graphics,MPSCompute}RunningProcesses`
if not _pynvml_installation_corrupted:
# pylint: disable-next=ungrouped-imports
from pynvml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject
if not _is_musa:
from pynvml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject
else:
from pymtml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject

def _nvmlLookupFunctionPointer(symbol: str) -> _Any | None:
try:
Expand Down Expand Up @@ -671,7 +681,11 @@ def __nvml_device_get_running_processes(

# First call to get the size
c_count = _ctypes.c_uint(0)
fn = _nvmlGetFunctionPointer(f'{func}{version_suffix}')
try:
fn = _nvmlGetFunctionPointer(f'{func}{version_suffix}')
except Exception:
return []

ret = fn(handle, _ctypes.byref(c_count), None)

if ret == NVML_SUCCESS:
Expand Down Expand Up @@ -876,7 +890,11 @@ def nvmlDeviceGetMemoryInfo( # pylint: disable=function-redefined
'function `nvmlDeviceGetMemoryInfo`.',
)

fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetMemoryInfo{version_suffix}')
try:
fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetMemoryInfo{version_suffix}')
except Exception:
return NA

ret = fn(handle, _ctypes.byref(c_memory))
if ret != NVML_SUCCESS:
raise NVMLError(ret)
Expand Down Expand Up @@ -952,15 +970,21 @@ def nvmlDeviceGetTemperature( # pylint: disable=function-redefined
c_temp_v1.version = nvmlTemperature_v1
# pylint: disable-next=attribute-defined-outside-init
c_temp_v1.sensorType = _ctypes.c_uint(sensor)
fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperatureV')
try:
fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetTemperatureV{version_suffix}')
except Exception:
return NA
ret = fn(handle, _ctypes.byref(c_temp_v1))
if ret != NVML_SUCCESS:
raise NVMLError(ret)
return int(c_temp_v1.temperature)

if version_suffix == '':
c_temp = _ctypes.c_uint(0)
fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperature')
try:
fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperature')
except Exception:
return NA
ret = fn(handle, _ctypes.c_uint(sensor), _ctypes.byref(c_temp))
if ret != NVML_SUCCESS:
raise NVMLError(ret)
Expand Down
10 changes: 10 additions & 0 deletions nvitop/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,16 @@ def cache_deactivate(self: object) -> None:
wrapped.cache_deactivate = cache_deactivate # type: ignore[attr-defined]
return wrapped # type: ignore[return-value]

def is_musa() -> bool:
"""Check if the current Python interpreter is Musa."""
try:
import pymtml # noqa: F401
pymtml.nvmlInit()
pymtml.nvmlShutdown()
except Exception:
return False

return True

if __name__ == '__main__':
import doctest
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ classifiers = [
dependencies = [
# Sync with nvitop/version.py and requirements.txt
"nvidia-ml-py >= 11.450.51, < 13.591.0a0",
"mthreads-ml-py >= 2.2.1",
"psutil >= 5.6.6",
"colorama >= 0.4.0; platform_system == 'Windows'",
"windows-curses >= 2.2.0; platform_system == 'Windows'",
Expand Down