diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index 020f66323d..88d1b3304c 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -656,13 +656,14 @@ def check_cuda_runtime(): driver_version = ctypes.c_int() runtime_version = ctypes.c_int() - if cuda.cudaDriverGetVersion(ctypes.byref(driver_version)) == 0 and \ - cuda.cudaRuntimeGetVersion(ctypes.byref(runtime_version)) == 0: - driver_version = driver_version.value - runtime_version = runtime_version.value - - driver_v = parse(str(driver_version/1000)) - runtime_v = parse(str(runtime_version/1000)) + # Check the get*Version call succeeds and is a non-zero value + call_success = cuda.cudaDriverGetVersion(ctypes.byref(driver_version)) == 0 + call_success &= cuda.cudaRuntimeGetVersion(ctypes.byref(runtime_version)) == 0 + call_success &= bool(driver_version.value) + + if call_success: + driver_v = parse(str(driver_version.value/1000)) + runtime_v = parse(str(runtime_version.value/1000)) # First check the "major" version, known to be incompatible if driver_v.major < runtime_v.major: raise RuntimeError(