Description
Problem
Numba-CUDA makes assumptions about integer types in multiple places, including:
- Literals
- Internal structures (e.g., array shapes and strides)
These rules follow the integer-typing NBEP (Numba Enhancement Proposal): https://numba.readthedocs.io/en/stable/proposals/integer-typing.html
This behavior originates from the original Numba implementation, where CPU execution is largely unaffected by aggressive upcasting. On GPUs, however, using int64 everywhere can introduce significant performance penalties. This is especially true for compute-bound kernels, where register pressure determines occupancy and the extra registers consumed by 64-bit arithmetic can cause severe performance degradation.
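The occupancy impact can be estimated with simple arithmetic. The sketch below assumes a 64K-register SM register file (the same figure the reproduction script uses); the per-thread register counts correspond to the 6656 / 3584 registers-per-block numbers reported later in this issue (256 threads per block):

```python
def max_resident_blocks(regs_per_thread, threads_per_block, regfile_size=64 * 1024):
    """Upper bound on blocks resident per SM, limited only by the register file."""
    return regfile_size // (regs_per_thread * threads_per_block)

# 16x16 = 256 threads per block:
print(max_resident_blocks(26, 256))  # int64-heavy kernel (26 regs/thread) -> 9
print(max_resident_blocks(14, 256))  # int32 kernel (14 regs/thread) -> 18
```

Real occupancy is also bounded by shared memory and hardware block/warp limits; this only shows the register-file term, which is the one that changes here.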
Possible Solutions
A) NBEP
- Remove NBEP entirely
- Make NBEP an optional, user-configurable flag
B) Internal Structures
- Introduce per-compilation flags that allow users to specify the types used for internal structures. Caveat: this approach would prevent mixing functions compiled with different flags within a single kernel.
- Introduce a configurable integer type per array
C) Literals
- Use minimal storage width for literals
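One way to read option C: type each literal with the smallest standard machine integer that can represent it. A minimal sketch of such a rule in plain Python (a hypothetical helper, not an existing Numba API):

```python
def minimal_signed_width(value):
    """Smallest standard signed width (8/16/32/64 bits) that can hold `value`."""
    for width in (8, 16, 32, 64):
        if -(1 << (width - 1)) <= value < (1 << (width - 1)):
            return width
    raise OverflowError(f"{value} does not fit in a 64-bit signed integer")

print(minimal_signed_width(7))        # -> 8
print(minimal_signed_width(1000))     # -> 16
print(minimal_signed_width(1000000))  # -> 32
```

Arithmetic on such narrow literals would then widen only as far as the other operand requires, instead of jumping straight to int64.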
Environment details (please complete the following information):
- Environment location: WSL (local machine)
- Method of numba-cuda install: pip install
- Python env:
numba-cuda==0.27.0
numba==0.63.1
Python 3.11.14
Additional context
Example code demonstrating the issue:
from numba import cuda, int32
import numpy as np

# int32 constants, so arithmetic inside the kernel can stay 32-bit
I32_0 = int32(0)
I32_1 = int32(1)
I32_2 = int32(2)
I32_3 = int32(3)
I32_4 = int32(4)
I32_5 = int32(5)
I32_6 = int32(6)
I32_7 = int32(7)
I32_1000 = int32(1000)
I32_1000000 = int32(1000000)
I32_19 = int32(19)
I32_23 = int32(23)
I32_31 = int32(31)
I32_11 = int32(11)
I32_13 = int32(13)
I32_17 = int32(17)

def time_numba_prep_args(kernel, grid_dim, block_dim, shared_memory_size, ncycles, prep_args):
    stream = cuda.stream()
    start, stop = cuda.event(), cuda.event()
    cuda.synchronize()
    # Warm-up launch to exclude JIT compilation from the measurement
    kernel[grid_dim, block_dim, stream, shared_memory_size](*prep_args(stream))
    stream.synchronize()
    total_time_ms = 0.0
    for _ in range(ncycles):
        prepared_args = prep_args(stream)
        stream.synchronize()
        start.record(stream)
        kernel[grid_dim, block_dim, stream, shared_memory_size](*prepared_args)
        stop.record(stream)
        stream.synchronize()
        total_time_ms += cuda.event_elapsed_time(start, stop)
    return total_time_ms / ncycles

def main():
    N_REPS = 20
    M = 1024 * 8
    WARP_SIZE = 32
    TILE_SIZE = WARP_SIZE // 2
    TILE_SIZE_NUMBA = int32(TILE_SIZE)
    assert M % WARP_SIZE == 0
    SHAPE = (M, M)
    N_TILE_PER_DIM = M // TILE_SIZE
    N_TILE = N_TILE_PER_DIM ** 2
    N_THREADS = (TILE_SIZE, TILE_SIZE)
    GRID_DIM = (N_TILE_PER_DIM, N_TILE_PER_DIM)

    @cuda.jit
    def MAIN_KERNEL(matrix):
        x = int32(cuda.blockIdx.x) * TILE_SIZE_NUMBA + int32(cuda.threadIdx.x)
        y = int32(cuda.blockIdx.y) * TILE_SIZE_NUMBA + int32(cuda.threadIdx.y)
        if x < int32(matrix.shape[0]) and y < int32(matrix.shape[1]):
            val_xp1 = I32_0
            val_xm1 = I32_0
            if x + I32_1 < int32(matrix.shape[0]):
                val_xp1 = matrix[x + I32_1, y]
                val_xm1 = -matrix[x + I32_1, y]
            result = int32(matrix[x, y])
            val_xp1 = int32(val_xp1)
            val_xm1 = int32(val_xm1)
            # Compute-bound loop: all operands are explicitly int32
            for _ in range(I32_0, I32_1000, I32_1):
                temp = (result * I32_3 + val_xp1 * I32_2 + val_xm1) % I32_1000000
                temp = (temp * I32_7 - result * I32_5 + val_xp1 * I32_11) % I32_1000000
                temp = (temp * I32_13 + val_xm1 * I32_17 - result * I32_19) % I32_1000000
                result = (temp + result * I32_23) % I32_1000000
                result = result ^ (result + I32_1)
                result = result + (result * I32_31) % I32_1000
            matrix[x, y] = result

    print(f"Problem size: M: {M}, WARP_SIZE: {WARP_SIZE}, N_TILE: {N_TILE}, GRID_DIM: {GRID_DIM}, N_THREADS: {N_THREADS}")
    matrix = np.ones(SHAPE, dtype=np.int32)
    matrix_d = cuda.to_device(matrix)
    MAIN_KERNEL[GRID_DIM, N_THREADS](matrix_d)
    cuda.synchronize()
    regs = list(MAIN_KERNEL.get_regs_per_thread().values())[0]
    print(f"Regs per block: {regs * N_THREADS[0] * N_THREADS[1]}")
    print(f"Max blocks running: {64*1024//(regs*N_THREADS[0]*N_THREADS[1])}")
    time = time_numba_prep_args(MAIN_KERNEL, GRID_DIM, N_THREADS, 0, N_REPS, lambda _: (matrix_d,))
    print(f"Time: {time} ms")

if __name__ == "__main__":
    main()
With NBEP enabled (default behavior):
Problem size: M: 8192, WARP_SIZE: 32, N_TILE: 262144, GRID_DIM: (512, 512), N_THREADS: (16, 16)
Regs per block: 6656
Max blocks running: 9
Time: 902.8884979248047 ms
With integer upcasting removed (quick patch applied):
Problem size: M: 8192, WARP_SIZE: 32, N_TILE: 262144, GRID_DIM: (512, 512), N_THREADS: (16, 16)
Regs per block: 3584
Max blocks running: 18
Time: 244.8684539794922 ms
Patch applied (NBEP removed + 4-byte integers enforced):
--- a/numba_cuda/numba/cuda/typing/builtins.py
+++ b/numba_cuda/numba/cuda/typing/builtins.py
@@ -164,7 +164,7 @@ class PairSecond(AbstractTemplate):
def choose_result_bitwidth(*inputs):
- return max(types.intp.bitwidth, *(tp.bitwidth for tp in inputs))
+ return max(0, *(tp.bitwidth for tp in inputs))
def choose_result_int(*inputs):
@@ -179,8 +179,8 @@ def choose_result_int(*inputs):
# The "machine" integer types to take into consideration for operator typing
# (according to the integer typing NBEP)
-machine_ints = sorted(set((types.intp, types.int64))) + sorted(
- set((types.uintp, types.uint64))
+machine_ints = sorted(set((types.int8, types.int16, types.int32))) + sorted(
+    set((types.uint8, types.uint16, types.uint32))
+ set((types.uint8, types.uint16, types.uint32))
)

Disabling NBEP-style upcasting results in ~3.7× faster execution, due to reduced register pressure (3584 vs. 6656 registers per block) and int32 arithmetic throughout the kernel.
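The effect of the patched typing rule can be modeled in plain Python. This is only a sketch of the width-selection logic, not the actual Numba code path; the functions take bitwidths directly rather than Numba type objects:

```python
def choose_result_bitwidth_nbep(intp_bitwidth, *input_bitwidths):
    # Upstream rule: the result is at least as wide as intp (64-bit on 64-bit hosts),
    # so two int32 operands still produce an int64 result.
    return max(intp_bitwidth, *input_bitwidths)

def choose_result_bitwidth_patched(*input_bitwidths):
    # Patched rule: the result is only as wide as the widest input.
    return max(0, *input_bitwidths)

print(choose_result_bitwidth_nbep(64, 32, 32))  # -> 64 (upcast to int64)
print(choose_result_bitwidth_patched(32, 32))   # -> 32 (stays int32)
```

This makes clear why the patch halves the register cost of integer-heavy code: every intermediate in the kernel's loop stays 32-bit instead of being widened to 64-bit.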