Commit cc1d4c5

Implement getLaunchKernelExHandle and defineGetFunctionHandle for Windows (#2745)
Part of #2478. These are quite stable changes, so we can merge them without CI on Windows, @gshimansky, if you don't mind. Once we test them on the Nvidia UTs I might try to upstream them.

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 3b53ca9 commit cc1d4c5

File tree: 2 files changed, +53 -0 lines changed

third_party/nvidia/backend/driver.c

Lines changed: 28 additions & 0 deletions
@@ -1,5 +1,11 @@
 #include "cuda.h"
+#ifdef WIN32
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <windows.h>
+#else
 #include <dlfcn.h>
+#endif
 #include <stdbool.h>
 #define PY_SSIZE_T_CLEAN
 #include <Python.h>
@@ -161,6 +167,27 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
     CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
     CUtensorMapFloatOOBfill oobFill);
 
+#ifdef WIN32
+#define defineGetFunctionHandle(name, symbolName)                             \
+  static symbolName##_t name() {                                              \
+    /* Open the shared library */                                             \
+    HMODULE handle = LoadLibraryA("nvcuda.dll");                              \
+    if (!handle) {                                                            \
+      PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");       \
+      return NULL;                                                            \
+    }                                                                         \
+    symbolName##_t funcHandle =                                               \
+        (symbolName##_t)GetProcAddress((HMODULE)handle, #symbolName);         \
+    /* Check for errors */                                                    \
+    long err = GetLastError();                                                \
+    if (err) {                                                                \
+      PyErr_SetString(PyExc_RuntimeError,                                     \
+                      "Failed to retrieve " #symbolName " from nvcuda.dll");  \
+      return NULL;                                                            \
+    }                                                                         \
+    return funcHandle;                                                        \
+  }
+#else
 #define defineGetFunctionHandle(name, symbolName)                             \
   static symbolName##_t name() {                                              \
     /* Open the shared library */                                             \
@@ -182,6 +209,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
     }                                                                          \
     return funcHandle;                                                         \
   }
+#endif
 
 defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
                         cuOccupancyMaxActiveClusters);
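
To make the new Windows branch easier to follow, the sketch below shows roughly what the `WIN32` variant of `defineGetFunctionHandle` expands to for the invocation at the end of the hunk, `defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, cuOccupancyMaxActiveClusters)`. This is an illustrative expansion, not code from the commit, and it assumes the `cuOccupancyMaxActiveClusters_t` function-pointer typedef that driver.c declares earlier, outside this diff:

```c
/* Illustrative expansion only (not code from the commit). Assumes the
 * cuOccupancyMaxActiveClusters_t typedef declared earlier in driver.c. */
static cuOccupancyMaxActiveClusters_t getCuOccupancyMaxActiveClustersHandle() {
  /* Open the shared library; if nvcuda.dll is already loaded this just
   * bumps its reference count and returns the existing module handle. */
  HMODULE handle = LoadLibraryA("nvcuda.dll");
  if (!handle) {
    PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");
    return NULL;
  }
  /* Resolve the symbol by name, as dlsym() does on the POSIX path. */
  cuOccupancyMaxActiveClusters_t funcHandle =
      (cuOccupancyMaxActiveClusters_t)GetProcAddress(
          (HMODULE)handle, "cuOccupancyMaxActiveClusters");
  /* The macro reports failure through GetLastError(). */
  long err = GetLastError();
  if (err) {
    PyErr_SetString(PyExc_RuntimeError,
                    "Failed to retrieve cuOccupancyMaxActiveClusters "
                    "from nvcuda.dll");
    return NULL;
  }
  return funcHandle;
}
```

The Windows macro mirrors the existing POSIX variant: `LoadLibraryA("nvcuda.dll")` stands in for the `dlopen` call, `GetProcAddress` for `dlsym`, and `GetLastError()` replaces the error check used on the POSIX side.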

third_party/nvidia/backend/driver.py

Lines changed: 25 additions & 0 deletions
@@ -167,7 +167,12 @@ def format_of(ty):
 #include \"cuda.h\"
 #include <stdbool.h>
 #include <Python.h>
+#ifndef _WIN32
 #include <dlfcn.h>
+#else
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#endif
 
 static inline void gpuAssert(CUresult code, const char *file, int line)
 {{
@@ -190,6 +195,7 @@ def format_of(ty):
 
 typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);
 
+#ifndef _WIN32
 static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
   // Open the shared library
   void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
@@ -208,6 +214,25 @@ def format_of(ty):
   }}
   return cuLaunchKernelExHandle;
 }}
+#else
+static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
+  // Open the shared library
+  HMODULE handle = LoadLibraryA("nvcuda.dll");
+  if (!handle) {{
+    PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");
+    return NULL;
+  }}
+  cuLaunchKernelEx_t cuLaunchKernelExHandle =
+      (cuLaunchKernelEx_t)GetProcAddress((HMODULE)handle, "cuLaunchKernelEx");
+  // Check for errors
+  long error = GetLastError();
+  if (error) {{
+    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from nvcuda.dll");
+    return NULL;
+  }}
+  return cuLaunchKernelExHandle;
+}}
+#endif
 
 static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(params)} }};
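
For context, here is the same `LoadLibraryA`/`GetProcAddress` lookup as a small standalone Windows-only program. It is not part of the commit: it checks the pointer returned by `GetProcAddress` directly rather than calling `GetLastError()`, which is a common alternative for detecting a missing export, and it assumes a CUDA 12+ `cuda.h` on the include path so that `CUlaunchConfig` and the `cuLaunchKernelEx` signature are declared.

```c
/* Hypothetical standalone sketch (not from the commit): resolve
 * cuLaunchKernelEx from nvcuda.dll using the same LoadLibraryA /
 * GetProcAddress pattern as the generated launcher. Windows-only,
 * requires a CUDA 12+ cuda.h on the include path. */
#include <stdio.h>
#include <windows.h>

#include "cuda.h"

/* Same function-pointer type as in the generated code above. */
typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig *config,
                                       CUfunction f, void **kernelParams,
                                       void **extra);

int main(void) {
  /* nvcuda.dll is the Windows CUDA driver library (libcuda.so.1 on Linux). */
  HMODULE handle = LoadLibraryA("nvcuda.dll");
  if (!handle) {
    fprintf(stderr, "Failed to open nvcuda.dll\n");
    return 1;
  }
  /* Checking the returned pointer is an alternative to GetLastError(). */
  cuLaunchKernelEx_t fn =
      (cuLaunchKernelEx_t)GetProcAddress(handle, "cuLaunchKernelEx");
  if (!fn) {
    fprintf(stderr, "cuLaunchKernelEx not found in nvcuda.dll\n");
    return 1;
  }
  printf("cuLaunchKernelEx resolved at %p\n", (void *)fn);
  return 0;
}
```

On a machine whose driver exports `cuLaunchKernelEx`, this prints the resolved address; otherwise it reports the missing symbol. One likely reason for the run-time lookup on both platforms is that `cuLaunchKernelEx` only exists in newer CUDA drivers, so resolving it lazily avoids a hard link-time dependency.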
