diff --git a/bitsandbytes/backends/triton/kernels_4bit.py b/bitsandbytes/backends/triton/kernels_4bit.py index 0e94f49e8..bdd59fad2 100644 --- a/bitsandbytes/backends/triton/kernels_4bit.py +++ b/bitsandbytes/backends/triton/kernels_4bit.py @@ -66,7 +66,8 @@ def quantize_fp4_blockwise_kernel( packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) - out_mask = out_offsets < n_elements // 2 + # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n + out_mask = out_offsets < (n_elements - n_elements // 2) tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask) @@ -148,7 +149,8 @@ def quantize_nf4_blockwise_kernel( packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) - out_mask = out_offsets < n_elements // 2 + # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n + out_mask = out_offsets < (n_elements - n_elements // 2) tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask) @@ -330,7 +332,14 @@ def dequant_nf4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.const # ) @triton.jit def dequant_4bit_kernel( - a_ptr, c_ptr, quant_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr + a_ptr, + c_ptr, + quant_ptr, + absmax_ptr, + num_paired_elements, + num_output_elements, + QUANT_BLOCK: tl.constexpr, + SPLIT_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. block_start = pid * SPLIT_SIZE @@ -350,7 +359,7 @@ def dequant_4bit_kernel( out_block_start = pid * SPLIT_SIZE * 2 offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2) - mask = offs < num_paired_elements * 2 + mask = offs < num_output_elements tl.store(c_ptr + offs, out_dq, mask) @@ -367,7 +376,13 @@ def dequant_4bit_kernel( # ) @triton.jit def dequant_fp4_kernel( - a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr + a_ptr, + c_ptr, + absmax_ptr, + num_paired_elements, + num_output_elements, + QUANT_BLOCK: tl.constexpr, + SPLIT_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. block_start = pid * SPLIT_SIZE @@ -386,7 +401,7 @@ def dequant_fp4_kernel( out_block_start = pid * SPLIT_SIZE * 2 offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2) - mask = offs < num_paired_elements * 2 + mask = offs < num_output_elements tl.store(c_ptr + offs, out_dq, mask) @@ -403,7 +418,13 @@ def dequant_fp4_kernel( # ) @triton.jit def dequant_nf4_kernel( - a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr + a_ptr, + c_ptr, + absmax_ptr, + num_paired_elements, + num_output_elements, + QUANT_BLOCK: tl.constexpr, + SPLIT_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. block_start = pid * SPLIT_SIZE @@ -422,7 +443,7 @@ def dequant_nf4_kernel( out_block_start = pid * SPLIT_SIZE * 2 offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2) - mask = offs < num_paired_elements * 2 + mask = offs < num_output_elements tl.store(c_ptr + offs, out_dq, mask) @@ -439,15 +460,16 @@ def dequantize_4bit_impl( # Elements are in uint8 format, so interleaved # so total amount of data is 2 * elem_count number_of_paired_elements = A.numel() + num_output_elements = out.numel() # we assume that split_size > quant_blocksize SPLIT_SIZE = 256 # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), ) grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) if quant_type == "fp4": - dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE) + dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE) else: - dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE) + dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE) def dequantize_4bit_impl_passing_code( @@ -459,12 +481,15 @@ def dequantize_4bit_impl_passing_code( out: torch.Tensor, ) -> None: number_of_paired_elements = A.numel() + num_output_elements = out.numel() # we assume that split_size > quant_blocksize SPLIT_SIZE = 256 # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), ) grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) - dequant_4bit_kernel[grid](A, out, code, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE) + dequant_4bit_kernel[grid]( + A, out, code, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE + ) ######################### Fallback dequantization functions ######################### diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index 66bff3c94..3a16961fa 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -82,7 +82,8 @@ def quantize_4bit( blocks = -(n // -(blocksize * 2)) absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype) - out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8) + # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n + out = torch.empty((n - n // 2, 1), device=A.device, dtype=torch.uint8) with torch_accelerator_module.device(A.device): kernels_4bit.quantize_4bit_blockwise_triton( diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py index a0620dc4b..e565f1d34 100644 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -38,7 +38,7 @@ def _dequantize_4bit_impl( get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), - ct.c_int(out.numel()), + ct.c_int64(out.numel()), _get_tensor_stream(A), ) if dtype == torch.bfloat16: diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 340f06145..868cb3153 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -303,57 +303,57 @@ void spmm_coo_very_sparse_naive_int8( #if BUILD_XPU void dequantizeBlockwise_fp16( - float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp16_fp4( - float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp16_nf4( - float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp32( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp32_fp4( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp32_nf4( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16( - float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_fp4( - float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_nf4( - float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); @@ -728,57 +728,57 @@ void cgemm_4bit_inference_naive_fp32( #if BUILD_XPU void cdequantize_blockwise_fp16_fp4( - float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp16( - float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp16_nf4( - float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp32( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp32_fp4( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp32_nf4( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_bf16( - float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_bf16_fp4( - float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_bf16_nf4( - float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n, sycl::queue* stream ) { dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); diff --git a/csrc/xpu_kernels.cpp b/csrc/xpu_kernels.cpp index 8ee8add98..270251e18 100644 --- a/csrc/xpu_kernels.cpp +++ b/csrc/xpu_kernels.cpp @@ -95,20 +95,20 @@ inline float dDequantizeNF4(unsigned char val) { template SYCL_EXTERNAL void kDequantizeBlockwise::operator()(sycl::nd_item<1> item) const { - const int base_idx = item.get_group(0) * TILE_SIZE; - size_t local_idx = item.get_local_id(0) * NUM_PER_TH; + const int64_t base_idx = static_cast(item.get_group(0)) * TILE_SIZE; + int64_t local_idx = static_cast(item.get_local_id(0)) * NUM_PER_TH; float local_abs_max = -FLT_MAX; - int local_load_idx = 0; - int local_store_idx = 0; + int64_t local_load_idx = 0; + int64_t local_store_idx = 0; uint8_t qvals[NUM_PER_TH]; T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)]; if (DATA_TYPE > 0) { - local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx); - local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2); + local_load_idx = sycl::min(static_cast(TILE_SIZE), (n + 1) / 2 - base_idx); + local_store_idx = sycl::min(static_cast(TILE_SIZE * 2), n - base_idx * 2); } else { - local_load_idx = sycl::min(TILE_SIZE, n - base_idx); + local_load_idx = sycl::min(static_cast(TILE_SIZE), n - base_idx); local_store_idx = local_load_idx; } diff --git a/csrc/xpu_kernels.h b/csrc/xpu_kernels.h index caa7e6716..1e08f907c 100644 --- a/csrc/xpu_kernels.h +++ b/csrc/xpu_kernels.h @@ -8,7 +8,7 @@ template class kDequa public: SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; - kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int n_) + kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int64_t n_) : code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), n(n_) {} private: @@ -17,7 +17,7 @@ template class kDequa float* absmax; T* out; const int blocksize; - const int n; + const int64_t n; }; template class kgemv_4bit_inference { diff --git a/csrc/xpu_ops.cpp b/csrc/xpu_ops.cpp index 48c986fc4..6d9d89b15 100644 --- a/csrc/xpu_ops.cpp +++ b/csrc/xpu_ops.cpp @@ -3,7 +3,7 @@ template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int64_t n, sycl::queue* stream ) { auto& queue = *stream; const int workgroup_size = 128; @@ -55,35 +55,35 @@ void gemv_4bit_inference( //============================================================== template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n, sycl::queue* stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n, sycl::queue* stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n, sycl::queue* stream ); diff --git a/csrc/xpu_ops.h b/csrc/xpu_ops.h index a5ea80f97..ffe7c9242 100644 --- a/csrc/xpu_ops.h +++ b/csrc/xpu_ops.h @@ -30,7 +30,7 @@ static inline void sycl_comp_kernel_submit(sycl::nd_range range, sycl::queu template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int64_t n, sycl::queue* stream ); template void gemv_4bit_inference(