File tree Expand file tree Collapse file tree 1 file changed +7
-12
lines changed
Expand file tree Collapse file tree 1 file changed +7
-12
lines changed Original file line number Diff line number Diff line change @@ -656,18 +656,13 @@ def check_cuda_runtime():
656656 driver_version = ctypes .c_int ()
657657 runtime_version = ctypes .c_int ()
658658
659- if cuda .cudaDriverGetVersion (ctypes .byref (driver_version )) == 0 and \
660- cuda .cudaRuntimeGetVersion (ctypes .byref (runtime_version )) == 0 :
661- driver_version = driver_version .value
662- runtime_version = runtime_version .value
663-
664- if driver_version == 0 :
665- # cudart present but no driver detected. Likely isolation
666- # run such as version check or within a docker build.
667- return
668-
669- driver_v = parse (str (driver_version / 1000 ))
670- runtime_v = parse (str (runtime_version / 1000 ))
659+ # Check the get*Version call succeeds and is a non-zero value
660+ if cuda .cudaDriverGetVersion (ctypes .byref (driver_version )) == 0 \
661+ and cuda .cudaRuntimeGetVersion (ctypes .byref (runtime_version )) == 0 \
662+ and driver_version .value :
663+
664+ driver_v = parse (str (driver_version .value / 1000 ))
665+ runtime_v = parse (str (runtime_version .value / 1000 ))
671666 # First check the "major" version, known to be incompatible
672667 if driver_v .major < runtime_v .major :
673668 raise RuntimeError (
You can’t perform that action at this time.
0 commit comments