Skip to content

Commit e5ef7e2

Browse files
authored
Fix CUDA context leak in with cuda.gpus[N]: context manager (#855)
## Summary - **Fix**: `Device.get_primary_context()` calls `self._dev.set_current()` on first invocation, which leaves the primary context active on the thread. The caller `_activate_context_for` then calls `push()`, saving that already-active context onto the stack. On exit, `pop()` restores it — so the context remains current after the `with` block (a leak). The fix pops the context left by `set_current()` immediately after obtaining the handle, so `get_primary_context()` upholds its documented contract: *"Note: it is not pushed to the CPU thread."* - **Tests**: Adds two regression tests in `test_context_stack.py` — one verifying no context remains after `with cuda.gpus[0]: pass` on a clean stack, and another verifying the previous context is properly restored. Made with [Cursor](https://cursor.com) --------- Co-authored-by: Michael Wang <isVoid@users.noreply.github.com>
1 parent 6803734 commit e5ef7e2

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,11 +534,18 @@ def get_primary_context(self):
534534
f"{self} has compute capability < {MIN_REQUIRED_CC}"
535535
)
536536

537+
prev = get_cuda_native_handle(driver.cuCtxGetCurrent())
537538
self._dev.set_current()
538539
if CUDA_CORE_GT_0_6:
539540
ctx_handle = self._dev.context.handle
540541
else:
541542
ctx_handle = self._dev.context._handle
543+
# set_current() may push a context onto the thread's stack. Undo
544+
# that so callers (_activate_context_for) can push/pop symmetrically.
545+
# Only pop when set_current() actually changed the current context;
546+
# it is a no-op when a context for this device is already active.
547+
if get_cuda_native_handle(driver.cuCtxGetCurrent()) != prev:
548+
driver.cuCtxPopCurrent()
542549
self.primary_context = ctx = Context(
543550
weakref.proxy(self),
544551
ctx_handle,

numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,51 @@ def switch_gpu():
7777
self.assertEqual(int(devid), 1)
7878

7979

80+
@skip_on_cudasim("CUDA HW required")
81+
class TestContextLeak(CUDATestCase):
82+
"""Regression tests for context leaks from the gpu context manager."""
83+
84+
def test_gpus_context_manager_does_not_leak(self):
85+
# Regression test: ``with cuda.gpus[N]`` must not leave a CUDA
86+
# context on the thread after the block exits.
87+
the_driver = driver.driver
88+
89+
# Drain any pre-existing contexts from the stack.
90+
while the_driver.pop_active_context() is not None:
91+
pass
92+
93+
with cuda.gpus[0]:
94+
pass
95+
96+
# After exiting the context manager the current context must be null.
97+
with the_driver.get_active_context() as ac:
98+
self.assertIsNone(
99+
ac.context_handle,
100+
"CUDA context leaked after exiting cuda.gpus context manager",
101+
)
102+
103+
def test_gpus_context_manager_restores_previous_context(self):
104+
# If a context is already active before entering the context manager,
105+
# it must be restored on exit.
106+
the_driver = driver.driver
107+
108+
# Ensure device-0 context exists and is pushed.
109+
outer_ctx = cuda.current_context()
110+
outer_handle = int(outer_ctx.handle)
111+
112+
with cuda.gpus[0]:
113+
pass
114+
115+
with the_driver.get_active_context() as ac:
116+
self.assertIsNotNone(ac.context_handle)
117+
self.assertEqual(
118+
int(ac.context_handle),
119+
outer_handle,
120+
"Previous context was not restored after exiting "
121+
"cuda.gpus context manager",
122+
)
123+
124+
80125
@skip_on_cudasim("CUDA HW required")
81126
class Test3rdPartyContext(CUDATestCase):
82127
def tearDown(self):

0 commit comments

Comments
 (0)