Skip to content

Commit 388fc11

Browse files
authored
Merge pull request #476 from notoraptor/improve-cuda-detection-win
Loop over supported CUDA versions to find installed CUDA on Windows and Mac.
2 parents 2820925 + e97823a commit 388fc11

File tree

3 files changed

+57
-36
lines changed

3 files changed

+57
-36
lines changed

src/gpuarray_buffer_cuda.c

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ static int setup_done = 0;
129129
static int major = -1;
130130
static int minor = -1;
131131
static int setup_lib(void) {
132-
const char *ver;
133132
CUresult err;
134133
int res, tmp;
135134

@@ -140,20 +139,28 @@ static int setup_lib(void) {
140139
err = cuInit(0);
141140
if (err != CUDA_SUCCESS)
142141
return error_cuda(global_err, "cuInit", err);
143-
ver = getenv("GPUARRAY_CUDA_VERSION");
144-
if (ver == NULL || strlen(ver) != 2) {
145-
err = cuDriverGetVersion(&tmp);
146-
if (err != CUDA_SUCCESS)
147-
return error_set(global_err, GA_IMPL_ERROR, "cuDriverGetVersion failed");
148-
major = tmp / 1000;
149-
minor = (tmp / 10) % 10;
150-
} else {
151-
major = ver[0] - '0';
152-
minor = ver[1] - '0';
153-
}
154-
if (major > 9 || major < 0 || minor > 9 || minor < 0)
155-
return error_fmt(global_err, GA_VALUE_ERROR, "Invalid cuda version: %d.%d", major, minor);
142+
err = cuDriverGetVersion(&tmp);
143+
if (err != CUDA_SUCCESS)
144+
return error_set(global_err, GA_IMPL_ERROR, "cuDriverGetVersion failed");
145+
major = tmp / 1000;
146+
minor = (tmp / 10) % 10;
147+
/* Let's try to load a nvrtc corresponding to detected CUDA version. */
156148
res = load_libnvrtc(major, minor, global_err);
149+
if (res != GA_NO_ERROR) {
150+
/* Else, let's try to find a nvrtc corresponding to supported CUDA versions. */
151+
int versions[][2] = {{8, 0}, {7, 5}, {7, 0}};
152+
int versions_length = sizeof(versions) / sizeof(versions[0]);
153+
int i = 0;
154+
/* Skip versions that are higher or equal to the driver version */
155+
while (versions[i][0] > major ||
156+
(versions[i][0] == major && versions[i][1] >= minor)) i++;
157+
do {
158+
major = versions[i][0];
159+
minor = versions[i][1];
160+
res = load_libnvrtc(major, minor, global_err);
161+
i++;
162+
} while (res != GA_NO_ERROR && i < versions_length);
163+
}
157164
if (res != GA_NO_ERROR)
158165
return res;
159166
setup_done = 1;

src/loaders/libcublas.c

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#include <stdlib.h>
1+
/* To be able to use snprintf with any compiler including MSVC2008. */
2+
#include <private_config.h>
23

34
#include "libcublas.h"
45
#include "dyn_load.h"
@@ -41,21 +42,27 @@ int load_libcublas(int major, int minor, error *e) {
4142

4243
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
4344
{
44-
static const char DIGITS[] = "0123456789";
45-
char libname[] = "cublas64_??.dll";
46-
47-
libname[9] = DIGITS[major];
48-
libname[10] = DIGITS[minor];
49-
45+
char libname[64];
46+
int n;
47+
#ifdef DEBUG
48+
fprintf(stderr, "Loading cuBLAS %d.%d.\n", major, minor);
49+
#endif
50+
n = snprintf(libname, sizeof(libname), "cublas64_%d%d.dll", major, minor);
51+
if (n < 0 || n >= sizeof(libname))
52+
return error_set(e, GA_SYS_ERROR, "snprintf");
5053
lib = ga_load_library(libname, e);
5154
}
5255
#else /* Unix */
5356
#ifdef __APPLE__
5457
{
55-
static const char DIGITS[] = "0123456789";
56-
char libname[] = "/Developer/NVIDIA/CUDA-?.?/lib/libcublas.dylib";
57-
libname[23] = DIGITS[major];
58-
libname[25] = DIGITS[minor];
58+
char libname[128];
59+
int n;
60+
#ifdef DEBUG
61+
fprintf(stderr, "Loading cuBLAS %d.%d.\n", major, minor);
62+
#endif
63+
n = snprintf(libname, sizeof(libname), "/Developer/NVIDIA/CUDA-%d.%d/lib/libcublas.dylib", major, minor);
64+
if (n < 0 || n >= sizeof(libname))
65+
return error_set(e, GA_SYS_ERROR, "snprintf");
5966
lib = ga_load_library(libname, e);
6067
}
6168
#else

src/loaders/libnvrtc.c

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#include <stdlib.h>
1+
/* To be able to use snprintf with any compiler including MSVC2008. */
2+
#include <private_config.h>
23

34
#include "libcuda.h"
45
#include "libnvrtc.h"
@@ -27,22 +28,28 @@ int load_libnvrtc(int major, int minor, error *e) {
2728

2829
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
2930
{
30-
static const char DIGITS[] = "0123456789";
31-
char libname[] = "nvrtc64_??.dll";
32-
33-
libname[8] = DIGITS[major];
34-
libname[9] = DIGITS[minor];
31+
char libname[64];
32+
int n;
33+
#ifdef DEBUG
34+
fprintf(stderr, "Loading nvrtc %d.%d.\n", major, minor);
35+
#endif
36+
n = snprintf(libname, sizeof(libname), "nvrtc64_%d%d.dll", major, minor);
37+
if (n < 0 || n >= sizeof(libname))
38+
return error_set(e, GA_SYS_ERROR, "snprintf");
3539

3640
lib = ga_load_library(libname, e);
3741
}
3842
#else /* Unix */
3943
#ifdef __APPLE__
4044
{
41-
static const char DIGITS[] = "0123456789";
42-
/* Try the usual fullpath first */
43-
char libname[] = "/Developer/NVIDIA/CUDA-?.?/lib/libnvrtc.dylib";
44-
libname[23] = DIGITS[major];
45-
libname[25] = DIGITS[minor];
45+
char libname[128];
46+
int n;
47+
#ifdef DEBUG
48+
fprintf(stderr, "Loading nvrtc %d.%d.\n", major, minor);
49+
#endif
50+
n = snprintf(libname, sizeof(libname), "/Developer/NVIDIA/CUDA-%d.%d/lib/libnvrtc.dylib", major, minor);
51+
if (n < 0 || n >= sizeof(libname))
52+
return error_set(e, GA_SYS_ERROR, "snprintf");
4653
lib = ga_load_library(libname, e);
4754
}
4855
#else

0 commit comments

Comments
 (0)