Skip to content

Commit b78169e

Browse files
committed
Remove support for GPUARRAY_CUDA_VERSION.
Try to load detected CUDA by default, else look for supported versions. Add support for >9 versions (e.g. CUDA 10.1).
1 parent 63d262c commit b78169e

File tree

3 files changed

+32
-61
lines changed

3 files changed

+32
-61
lines changed

src/gpuarray_buffer_cuda.c

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ static int setup_lib(void) {
132132
const char *ver;
133133
CUresult err;
134134
int res, tmp;
135-
int search_version = 0;
136135

137136
if (!setup_done) {
138137
res = load_libcuda(global_err);
@@ -141,55 +140,24 @@ static int setup_lib(void) {
141140
err = cuInit(0);
142141
if (err != CUDA_SUCCESS)
143142
return error_cuda(global_err, "cuInit", err);
144-
ver = getenv("GPUARRAY_CUDA_VERSION");
145-
if (ver == NULL || strlen(ver) != 2) {
146-
err = cuDriverGetVersion(&tmp);
147-
if (err != CUDA_SUCCESS)
148-
return error_set(global_err, GA_IMPL_ERROR, "cuDriverGetVersion failed");
149-
major = tmp / 1000;
150-
minor = (tmp / 10) % 10;
151-
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) || defined(__APPLE__)
152-
/* We will dynamically search the right CUDA version only on Windows and Macintosh systems,
153-
and only if user has not explicitely specified GPUARRAY_CUDA_VERSION. */
154-
search_version = 1;
155-
#endif
156-
} else {
157-
major = ver[0] - '0';
158-
minor = ver[1] - '0';
159-
}
160-
/* NB: next line will cause problems if a CUDA 10.0 (or 9.11) is released in the future. */
161-
if (major > 9 || major < 0 || minor > 9 || minor < 0)
162-
return error_fmt(global_err, GA_VALUE_ERROR, "Invalid cuda version: %d.%d", major, minor);
163-
if (!search_version) {
164-
res = load_libnvrtc(major, minor, global_err);
165-
} else {
166-
/* First case in next array is reserved to eventually receive the version returned by cuDriverGetVersion(). */
167-
int versions[] = {-1, 80, 75};
168-
int versions_length = sizeof(versions) / sizeof(int);
169-
int current_version = major * 10 + minor;
143+
err = cuDriverGetVersion(&tmp);
144+
if (err != CUDA_SUCCESS)
145+
return error_set(global_err, GA_IMPL_ERROR, "cuDriverGetVersion failed");
146+
major = tmp / 1000;
147+
minor = (tmp / 10) % 10;
148+
/* Let's try to load a nvrtc corresponding to detected CUDA version. */
149+
res = load_libnvrtc(major, minor, global_err);
150+
if (res != GA_NO_ERROR) {
151+
/* Else, let's try to find a nvrtc corresponding to supported CUDA versions. */
152+
int versions[][2] = {{8, 0}, {7, 5}, {7, 0}};
153+
int versions_length = sizeof(versions) / (2 * sizeof(int));
170154
int i = 0;
171-
for (i = 1; i < versions_length && versions[i] != current_version; ++i);
172-
if (i == versions_length) {
173-
/* Current version not found in the list of versions. We add it at top of the list. */
174-
versions[0] = current_version;
175-
/* We will iterate on versions from the first. */
176-
i = 0;
177-
} else {
178-
/* Current version found in the list of known versions. No need to add it to the list. */
179-
i = 1;
180-
};
181155
do {
182-
major = versions[i] / 10;
183-
minor = versions[i] % 10;
156+
major = versions[i][0];
157+
minor = versions[i][1];
184158
res = load_libnvrtc(major, minor, global_err);
185159
++i;
186160
} while(res != GA_NO_ERROR && i < versions_length);
187-
#ifdef DEBUG
188-
if (res == GA_NO_ERROR)
189-
fprintf(stderr, "Detected CUDA %d.%d.\n", major, minor);
190-
else
191-
fprintf(stderr, "Unable to detect a CUDA version.\n");
192-
#endif
193161
}
194162
if (res != GA_NO_ERROR)
195163
return res;

src/loaders/libcublas.c

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,22 @@ int load_libcublas(int major, int minor, error *e) {
4545

4646
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
4747
{
48-
static const char DIGITS[] = "0123456789";
49-
char libname[] = "cublas64_??.dll";
48+
const char* libname_pattern = "cublas64_%d%d.dll";
49+
char libname[64];
5050

5151
#ifdef DEBUG
5252
fprintf(stderr, "Loading cuBLAS %d.%d.\n", major, minor);
5353
#endif
54-
libname[9] = DIGITS[major];
55-
libname[10] = DIGITS[minor];
54+
sprintf(libname, libname_pattern, major, minor);
5655

5756
lib = ga_load_library(libname, e);
5857
}
5958
#else /* Unix */
6059
#ifdef __APPLE__
6160
{
62-
static const char DIGITS[] = "0123456789";
63-
char libname[] = "/Developer/NVIDIA/CUDA-?.?/lib/libcublas.dylib";
64-
libname[23] = DIGITS[major];
65-
libname[25] = DIGITS[minor];
61+
const char* libname_pattern = "/Developer/NVIDIA/CUDA-%d.%d/lib/libcublas.dylib";
62+
char libname[128];
63+
sprintf(libname, libname_pattern, major, minor);
6664
lib = ga_load_library(libname, e);
6765
}
6866
#else

src/loaders/libnvrtc.c

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
#include <stdlib.h>
2+
#ifdef DEBUG
3+
/* For fprintf and stderr. */
4+
#include <stdio.h>
5+
#endif
26

37
#include "libcuda.h"
48
#include "libnvrtc.h"
@@ -27,22 +31,23 @@ int load_libnvrtc(int major, int minor, error *e) {
2731

2832
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
2933
{
30-
static const char DIGITS[] = "0123456789";
31-
char libname[] = "nvrtc64_??.dll";
34+
const char* libname_pattern = "nvrtc64_%d%d.dll";
35+
char libname[64];
3236

33-
libname[8] = DIGITS[major];
34-
libname[9] = DIGITS[minor];
37+
#ifdef DEBUG
38+
fprintf(stderr, "Loading nvrtc %d.%d.\n", major, minor);
39+
#endif
40+
sprintf(libname, libname_pattern, major, minor);
3541

3642
lib = ga_load_library(libname, e);
3743
}
3844
#else /* Unix */
3945
#ifdef __APPLE__
4046
{
41-
static const char DIGITS[] = "0123456789";
4247
/* Try the usual fullpath first */
43-
char libname[] = "/Developer/NVIDIA/CUDA-?.?/lib/libnvrtc.dylib";
44-
libname[23] = DIGITS[major];
45-
libname[25] = DIGITS[minor];
48+
const char* libname_pattern = "/Developer/NVIDIA/CUDA-%d.%d/lib/libnvrtc.dylib";
49+
char libname[128];
50+
sprintf(libname, libname_pattern, major, minor);
4651
lib = ga_load_library(libname, e);
4752
}
4853
#else

0 commit comments

Comments
 (0)