Skip to content

Commit a575895

Browse files
authored
[NVIDIA][Launcher] Ensure device context is valid before calling getPointer (#5276)
1 parent 2c0b791 commit a575895

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

python/test/unit/runtime/test_driver.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import sys
2+
from concurrent.futures import ThreadPoolExecutor
3+
import torch
24

35
import triton
6+
import triton.language as tl
47

58

69
def test_is_lazy():
@@ -12,3 +15,27 @@ def test_is_lazy():
1215
assert triton.runtime.driver.active._obj is None
1316
utils = triton.runtime.driver.active.utils # noqa: F841
1417
assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase"))
18+
19+
20+
def test_kernel_in_thread(device):
    """Regression test: launching a Triton kernel from a worker thread must
    still see (or establish) a valid device context before any CUDA API is
    called — a fresh thread starts with no current context."""
    data = torch.zeros((38016 * 1024, ), dtype=torch.float32, device=device)

    @triton.jit
    def _kernel(P, BLOCK: tl.constexpr):
        # Identity copy: load one BLOCK-sized slice of P and store it back.
        pid = tl.program_id(0).to(tl.int64)
        offset = pid * BLOCK + tl.arange(0, BLOCK)

        p = tl.load(P + offset)
        tl.store(P + offset, p)

    def launch():
        numel = data.numel()
        grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"]), )
        _kernel[grid](data, BLOCK=1024)
        getattr(torch, device).synchronize()

    # Warm up on the main thread first, then repeat the launch from a
    # brand-new worker thread that owns no CUDA context of its own.
    launch()
    with ThreadPoolExecutor(max_workers=1) as executor:
        executor.submit(launch).result()

third_party/nvidia/backend/driver.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,6 @@ def format_of(ty):
212212
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
213213
void *params[] = {{ {', '.join(params)} }};
214214
if (gridX*gridY*gridZ > 0) {{
215-
CUcontext pctx;
216-
CUDA_CHECK(cuCtxGetCurrent(&pctx));
217-
if (!pctx) {{
218-
// Ensure device context.
219-
CUdevice device;
220-
CUDA_CHECK(cuDeviceGet(&device, 0));
221-
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
222-
CUDA_CHECK(cuCtxSetCurrent(pctx));
223-
}}
224215
if (num_ctas == 1) {{
225216
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
226217
}} else {{
@@ -288,6 +279,9 @@ def format_of(ty):
288279
PyErr_Format(PyExc_ValueError,
289280
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
290281
ptr_info.valid = false;
282+
}} else if (status != CUDA_SUCCESS) {{
283+
CUDA_CHECK(status); // Catch any other cuda API errors
284+
ptr_info.valid = false;
291285
}}
292286
ptr_info.dev_ptr = dev_ptr;
293287
Py_DECREF(ret); // Thanks ChatGPT!
@@ -344,7 +338,22 @@ def format_of(ty):
344338
return (CUtensorMap*)(ptr_as_uint);
345339
}}
346340
341+
// Make sure the calling thread has a current CUDA context before any other
// driver-API call runs (e.g. cuPointerGetAttributes inside getPointer);
// a thread created after initialization starts with no current context.
// NOTE(review): braces are doubled because this C code is embedded in a
// Python format-string template in the launcher generator.
static void ensureCudaContext() {{
  CUcontext pctx;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {{
    // Ensure device context.
    // NOTE(review): falls back to device 0 when no context is current —
    // presumably matches the default device; confirm against callers that
    // may target a different device ordinal.
    CUdevice device;
    CUDA_CHECK(cuDeviceGet(&device, 0));
    // Retain the primary context (shared with the CUDA runtime) and make
    // it current on this thread.
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }}
}}
352+
347353
static PyObject* launch(PyObject* self, PyObject* args) {{
354+
// ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
355+
ensureCudaContext();
356+
348357
int gridX, gridY, gridZ;
349358
uint64_t _stream;
350359
uint64_t _function;

0 commit comments

Comments
 (0)