Skip to content

Commit 85256a6

Browse files
authored
Ensure device context before launching kernel (#3731)
If a kernel is launched on a thread which has not initialized a CUDA context (as can happen in the linked issue), it will throw an error. A simple fix is to call `cudaFree(0)` to establish a device context. Fixes #3729
1 parent f637ea7 commit 85256a6

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

third_party/nvidia/backend/driver.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,15 @@ def format_of(ty):
212212
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
213213
void *params[] = {{ {', '.join(params)} }};
214214
if (gridX*gridY*gridZ > 0) {{
215+
CUcontext pctx;
216+
CUDA_CHECK(cuCtxGetCurrent(&pctx));
217+
if (!pctx) {{
218+
// Ensure device context.
219+
CUdevice device;
220+
CUDA_CHECK(cuDeviceGet(&device, 0));
221+
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
222+
CUDA_CHECK(cuCtxSetCurrent(pctx));
223+
}}
215224
if (num_ctas == 1) {{
216225
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
217226
}} else {{

0 commit comments

Comments
 (0)