
Commit 34400d2

Fix indexing overflow issue for blockwise quantization with large tensor sizes (#1784)
Parent: c3b8de2 · Commit: 34400d2

4 files changed: +83 −21 lines changed

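All four changed files guard against the same failure mode, sketched here in isolation. The values and names below are illustrative only, not the library's actual constants: once a tensor has close to 2^31 - 1 elements, per-block index arithmetic done in signed 32-bit integers leaves the representable range, so computed block bounds wrap instead of growing.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int n = 2147483647;          // INT32_MAX elements, the supported maximum
    const int base_idx = 2147483392;   // start of the last 256-element block

    // 32-bit arithmetic: base_idx + 256 exceeds INT32_MAX. Signed overflow is
    // undefined behavior in C++; in practice the value wraps negative, so any
    // block bound computed this way is garbage.
    // const int block_end = base_idx + 256;   // broken for large n

    // Widening to 64 bits keeps every intermediate in range; clamping against n
    // then yields the number of valid elements in the tail block (255 here).
    const int64_t block_end = static_cast<int64_t>(base_idx) + 256;
    const int valid_items = static_cast<int>(std::min<int64_t>(block_end, n) - base_idx);
    std::printf("valid_items in last block = %d\n", valid_items);
    return 0;
}

The hunks below apply this pattern (64-bit intermediates, clamped or narrowed back to int only where the result is known to fit) at each spot where the element count enters 32-bit arithmetic.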

bitsandbytes/backends/cuda/ops.py

Lines changed: 2 additions & 2 deletions
@@ -326,7 +326,7 @@ def _(
        get_ptr(absmax),
        get_ptr(out),
        ct.c_int32(blocksize),
-       ct.c_int(n),
+       ct.c_int32(n),
    )

    if A.dtype == torch.bfloat16:
@@ -403,7 +403,7 @@ def _dequantize_4bit_impl(
        get_ptr(absmax),
        get_ptr(out),
        ct.c_int(blocksize),
-       ct.c_int(out.numel()),
+       ct.c_int32(out.numel()),
        _get_tensor_stream(A),
    )

csrc/kernels.cu

Lines changed: 9 additions & 6 deletions
@@ -328,14 +328,16 @@ __global__ void kQuantizeBlockwise(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
) {
-   const int n_full = gridDim.x * BLOCK_SIZE;
+   // This can overflow, so we clamp to INT32_MAX. We won't have more elements than this.
+   const int n_full = min(gridDim.x * BLOCK_SIZE, INT32_MAX);
+
+   const int base_idx = blockIdx.x * BLOCK_SIZE;
    int valid_items = 0;
-   const int base_idx = (blockIdx.x * BLOCK_SIZE);

    T vals[NUM_PER_TH];
    float rand_vals[NUM_PER_TH];
    unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH];
-   // float local_abs_max = -FLT_MAX;
+
    float local_abs_max = 0.0f;
    int local_rand_idx = 0;

@@ -358,8 +360,8 @@ __global__ void kQuantizeBlockwise(
    for (int i = threadIdx.x; i < 256; i += blockDim.x)
        smem_code[i] = code[i];

-   for (int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
-       valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
+   for (int64_t i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
+       valid_items = min(BLOCK_SIZE, static_cast<int>(n - i));
        local_abs_max = -FLT_MAX;

        __syncthreads();
@@ -442,7 +444,8 @@ __global__ void

    for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
        if (DATA_TYPE > 0) {
-           valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
+           // Cast n to int64_t to avoid overflow for large n
+           valid_items_load = min(TILE_SIZE, static_cast<int>((static_cast<int64_t>(n) + 1) / 2) - i);
            valid_items_store = min(TILE_SIZE * 2, n - i * 2);
        } else {
            valid_items_load = min(TILE_SIZE, n - i);
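A host-side sketch of the two traps the kernel hunks above guard against; block_size and n here are stand-in values, not the kernel's actual template parameters:

#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
    const int n = INT_MAX;        // 2^31 - 1 elements, the supported maximum
    const int block_size = 256;   // stand-in for the kernel's BLOCK_SIZE
    const int grid = static_cast<int>((static_cast<int64_t>(n) + block_size - 1) / block_size);

    // Trap 1: grid * block_size rounds n up to the next multiple of block_size,
    // which is 2^31 here and no longer fits in a signed 32-bit int. Doing the
    // product in 64 bits and clamping to INT_MAX keeps the loop bound valid,
    // since no more than INT_MAX elements can exist anyway.
    const int64_t covered = static_cast<int64_t>(grid) * block_size;
    const int n_full = static_cast<int>(std::min<int64_t>(covered, INT_MAX));
    std::printf("covered = %lld, clamped n_full = %d\n", static_cast<long long>(covered), n_full);

    // Trap 2: with 4-bit storage two values share one byte, so the packed length
    // is (n + 1) / 2. With n == INT_MAX the sum n + 1 overflows 32-bit arithmetic;
    // widening n first keeps the intermediate in range.
    const int packed_len = static_cast<int>((static_cast<int64_t>(n) + 1) / 2);
    std::printf("packed_len = %d\n", packed_len);
    return 0;
}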

csrc/ops.cu

Lines changed: 7 additions & 6 deletions
@@ -61,16 +61,17 @@ template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
    float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, cudaStream_t stream
) {
-   // printf("stream==%d\n",stream);
-   int num_blocks = n / blocksize;
-   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
-   int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
+   constexpr int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
+
+   // Upcast to int64 to avoid overflow for large n
+   int grid_blocks = ((int64_t)n + tile_size - 1) / tile_size;
+
    if (DATA_TYPE > 0)
        kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
-           <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n);
+           <<<grid_blocks, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n);
    else
        kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
-           <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);
+           <<<grid_blocks, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);

    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
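The host-side launch math can be checked in isolation as well; a small sketch of the rounding-up division after the change (tile sizes taken from the hunk above, everything else assumed):

#include <cstdint>
#include <cstdio>

// Rounding-up division as done in dequantizeBlockwise after this commit: the
// sum is formed in 64-bit so that n + tile_size - 1 cannot wrap when n is close
// to 2^31 - 1. The result (about 4.2 million blocks for tile_size 512) fits
// comfortably back into a 32-bit grid dimension.
static int grid_blocks_for(int n, int tile_size) {
    return static_cast<int>((static_cast<int64_t>(n) + tile_size - 1) / tile_size);
}

int main() {
    const int n = 2147483647;  // INT32_MAX elements
    std::printf("tile_size 512  -> %d blocks\n", grid_blocks_for(n, 512));
    std::printf("tile_size 1024 -> %d blocks\n", grid_blocks_for(n, 1024));
    // The pre-fix form, (n + tile_size - 1) / tile_size in plain int, overflows
    // the intermediate sum for this n and produces a meaningless grid size.
    return 0;
}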

tests/test_functional.py

Lines changed: 65 additions & 7 deletions
@@ -151,6 +151,34 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
        assert relerr < 0.012
        assert A2.dtype == dtype

+    @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+    @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+    @pytest.mark.parametrize("blocksize", [256], ids=id_formatter("blocksize"))
+    def test_dynamic_blockwise_quantization_large(self, device, dtype, blocksize):
+        """
+        Test that we can successfully quantize a large tensor. Note that the following limitations apply:
+        - On CUDA/XPU/ROCm, the maximum number of elements is limited to 2**31 - 1 due to int32 indexing in C++ kernels.
+        - On CPU, there is a significantly higher memory overhead for the quantization, so we skip this test.
+        - Verification of the accuracy for dequantization has too high memory overhead for this test.
+        """
+        if device not in ["cuda", "xpu"]:
+            pytest.skip("This test is only for CUDA and XPU devices due to memory constraints.")
+
+        data = torch.randn(2**31 - 1, device=device, dtype=dtype)
+        q_data, q_stats = F.quantize_blockwise(data, blocksize=blocksize)
+
+        assert q_data is not None
+        assert q_data.dtype == torch.uint8
+        assert q_data.numel() == data.numel()
+
+        # Dequant
+        del data
+        dq = F.dequantize_blockwise(q_data, q_stats)
+
+        assert dq.dtype == dtype
+        assert dq.numel() == q_data.numel()
+
    @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required")
    @pytest.mark.parametrize("hidden", [128])
    @pytest.mark.parametrize("blocksize", [4096, 16384])
@@ -1118,18 +1146,17 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
        A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
        qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
        A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
+        del qa, SA
+
+        assert A2.dtype == dtype

        err = (A1 - A2).abs().float()
+        del A2
+
        relerr = (err / (A1.abs().float() + 1e-8)).mean()
        err = err.mean()

-        assert A2.dtype == dtype
-
-        # With larger block sizes, we can expect this to blow up.
-        # At blocksize>=1024, don't even bother looking at relerr.
-        #
-        # Actually, the above is not true anymore after fixing the integer packing bug.
-        # The following values were taken from averaging 1k samples per test configuration after fixing the bug.
+        # The following values were taken from averaging 1k samples per test configuration.
        error_dict = dict()
        error_dict["fp4"] = dict()
        error_dict["nf4"] = dict()
@@ -1213,6 +1240,37 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
        assert err.item() < 0.11
        assert relerr.item() < 0.28

+    @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+    @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+    @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
+    def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
+        """
+        Test that we can successfully quantize a large tensor. Note that the following limitations apply:
+        - On CUDA/XPU/ROCm, the maximum number of elements is limited to 2**31 - 1 due to int32 indexing in C++ kernels.
+        - On CUDA, this test requires ~10GiB of memory for fp32
+        - On CPU, there is a significantly higher memory overhead for the quantization, so we skip this test.
+        - Verification of the accuracy for dequantization has too high memory overhead for this test.
+        """
+
+        if device not in ["cuda", "xpu"]:
+            pytest.skip("This test is only for CUDA and XPU devices due to memory constraints.")
+
+        A1 = torch.randn(2**31 - 1, device=device, dtype=dtype)
+        qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
+
+        assert qa is not None
+        assert qa.dtype == torch.uint8
+        assert qa.numel() == (2**31 - 1 + 1) // 2  # each byte holds 2 quantized values
+
+        # Dequant
+        del A1
+        dq = F.dequantize_4bit(qa, SA)
+
+        assert dq.dtype == dtype
+        assert dq.numel() == 2**31 - 1
+
    # @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
    @pytest.mark.parametrize("quant_type", ["nf4"])
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
