Skip to content

Commit ed01642

Browse files
committed
Try workaround for newer Numba
1 parent 42610fe commit ed01642

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

ci/scripts/python_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,4 @@ export PYARROW_TEST_PARQUET_ENCRYPTION
6969
export PYARROW_TEST_S3
7070

7171
# Testing PyArrow
72-
pytest -r s ${PYTEST_ARGS} --pyargs pyarrow
72+
pytest -r s ${PYTEST_ARGS} --pyargs pyarrow.tests.test_cuda_numba_interop

python/pyarrow/_cuda.pyx

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,14 @@ cdef class CudaBuffer(Buffer):
460460
"""
461461
import ctypes
462462
from numba.cuda.cudadrv.driver import MemoryPointer
463-
return MemoryPointer(self.context.to_numba(),
464-
pointer=ctypes.c_void_p(self.address),
465-
size=self.size)
463+
try:
464+
return MemoryPointer(self.context.to_numba(),
465+
pointer=ctypes.c_void_p(self.address),
466+
size=self.size)
467+
except TypeError:
468+
# Newer Numba does not take a context argument anymore
469+
return MemoryPointer(pointer=ctypes.c_void_p(self.address),
470+
size=self.size)
466471

467472
cdef getitem(self, int64_t i):
468473
return self.copy_to_host(position=i, nbytes=1)[0]

0 commit comments

Comments
 (0)