@@ -129,7 +129,6 @@ static int setup_done = 0;
129129static int major = -1 ;
130130static int minor = -1 ;
131131static int setup_lib (void ) {
132- const char * ver ;
133132 CUresult err ;
134133 int res , tmp ;
135134
@@ -140,20 +139,28 @@ static int setup_lib(void) {
140139 err = cuInit (0 );
141140 if (err != CUDA_SUCCESS )
142141 return error_cuda (global_err , "cuInit" , err );
143- ver = getenv ("GPUARRAY_CUDA_VERSION" );
144- if (ver == NULL || strlen (ver ) != 2 ) {
145- err = cuDriverGetVersion (& tmp );
146- if (err != CUDA_SUCCESS )
147- return error_set (global_err , GA_IMPL_ERROR , "cuDriverGetVersion failed" );
148- major = tmp / 1000 ;
149- minor = (tmp / 10 ) % 10 ;
150- } else {
151- major = ver [0 ] - '0' ;
152- minor = ver [1 ] - '0' ;
153- }
154- if (major > 9 || major < 0 || minor > 9 || minor < 0 )
155- return error_fmt (global_err , GA_VALUE_ERROR , "Invalid cuda version: %d.%d" , major , minor );
142+ err = cuDriverGetVersion (& tmp );
143+ if (err != CUDA_SUCCESS )
144+ return error_set (global_err , GA_IMPL_ERROR , "cuDriverGetVersion failed" );
145+ major = tmp / 1000 ;
146+ minor = (tmp / 10 ) % 10 ;
147+ /* Let's try to load an nvrtc corresponding to the detected CUDA version. */
156148 res = load_libnvrtc (major , minor , global_err );
149+ if (res != GA_NO_ERROR ) {
150+ /* Otherwise, let's try to find an nvrtc corresponding to one of the supported CUDA versions. */
151+ int versions [][2 ] = {{8 , 0 }, {7 , 5 }, {7 , 0 }};
152+ int versions_length = sizeof (versions ) / sizeof (versions [0 ]);
153+ int i = 0 ;
154+ /* Skip versions that are higher than or equal to the driver version */
155+ while (versions [i ][0 ] > major ||
156+ (versions [i ][0 ] == major && versions [i ][1 ] >= minor )) i ++ ;
157+ do {
158+ major = versions [i ][0 ];
159+ minor = versions [i ][1 ];
160+ res = load_libnvrtc (major , minor , global_err );
161+ i ++ ;
162+ } while (res != GA_NO_ERROR && i < versions_length );
163+ }
157164 if (res != GA_NO_ERROR )
158165 return res ;
159166 setup_done = 1 ;
0 commit comments