@@ -132,7 +132,6 @@ static int setup_lib(void) {
132132 const char * ver ;
133133 CUresult err ;
134134 int res , tmp ;
135- int search_version = 0 ;
136135
137136 if (!setup_done ) {
138137 res = load_libcuda (global_err );
@@ -141,55 +140,24 @@ static int setup_lib(void) {
141140 err = cuInit (0 );
142141 if (err != CUDA_SUCCESS )
143142 return error_cuda (global_err , "cuInit" , err );
144- ver = getenv ("GPUARRAY_CUDA_VERSION" );
145- if (ver == NULL || strlen (ver ) != 2 ) {
146- err = cuDriverGetVersion (& tmp );
147- if (err != CUDA_SUCCESS )
148- return error_set (global_err , GA_IMPL_ERROR , "cuDriverGetVersion failed" );
149- major = tmp / 1000 ;
150- minor = (tmp / 10 ) % 10 ;
151- #if defined(WIN32 ) || defined(_WIN32 ) || defined(WIN64 ) || defined(_WIN64 ) || defined(__APPLE__ )
152- /* We will dynamically search the right CUDA version only on Windows and Macintosh systems,
153- and only if user has not explicitely specified GPUARRAY_CUDA_VERSION. */
154- search_version = 1 ;
155- #endif
156- } else {
157- major = ver [0 ] - '0' ;
158- minor = ver [1 ] - '0' ;
159- }
160- /* NB: next line will cause problems if a CUDA 10.0 (or 9.11) is released in the future. */
161- if (major > 9 || major < 0 || minor > 9 || minor < 0 )
162- return error_fmt (global_err , GA_VALUE_ERROR , "Invalid cuda version: %d.%d" , major , minor );
163- if (!search_version ) {
164- res = load_libnvrtc (major , minor , global_err );
165- } else {
166- /* First case in next array is reserved to eventually receive the version returned by cuDriverGetVersion(). */
167- int versions [] = {-1 , 80 , 75 };
168- int versions_length = sizeof (versions ) / sizeof (int );
169- int current_version = major * 10 + minor ;
143+ err = cuDriverGetVersion (& tmp );
144+ if (err != CUDA_SUCCESS )
145+ return error_set (global_err , GA_IMPL_ERROR , "cuDriverGetVersion failed" );
146+ major = tmp / 1000 ;
147+ minor = (tmp / 10 ) % 10 ;
148+ /* Let's try to load a nvrtc corresponding to detected CUDA version. */
149+ res = load_libnvrtc (major , minor , global_err );
150+ if (res != GA_NO_ERROR ) {
151+ /* Else, let's try to find a nvrtc corresponding to supported CUDA versions. */
152+ int versions [][2 ] = {{8 , 0 }, {7 , 5 }, {7 , 0 }};
153+ int versions_length = sizeof (versions ) / (2 * sizeof (int ));
170154 int i = 0 ;
171- for (i = 1 ; i < versions_length && versions [i ] != current_version ; ++ i );
172- if (i == versions_length ) {
173- /* Current version not found in the list of versions. We add it at top of the list. */
174- versions [0 ] = current_version ;
175- /* We will iterate on versions from the first. */
176- i = 0 ;
177- } else {
178- /* Current version found in the list of known versions. No need to add it to the list. */
179- i = 1 ;
180- };
181155 do {
182- major = versions [i ] / 10 ;
183- minor = versions [i ] % 10 ;
156+ major = versions [i ][ 0 ] ;
157+ minor = versions [i ][ 1 ] ;
184158 res = load_libnvrtc (major , minor , global_err );
185159 ++ i ;
186160 } while (res != GA_NO_ERROR && i < versions_length );
187- #ifdef DEBUG
188- if (res == GA_NO_ERROR )
189- fprintf (stderr , "Detected CUDA %d.%d.\n" , major , minor );
190- else
191- fprintf (stderr , "Unable to detect a CUDA version.\n" );
192- #endif
193161 }
194162 if (res != GA_NO_ERROR )
195163 return res ;
0 commit comments