From 6be1412307c517243675c2e2e5245d87bccb6735 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 28 Oct 2025 15:02:33 +0000 Subject: [PATCH 01/78] add template to support more dtypes Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 66 ++++++++++++++++++++++++++++++++++++---- csrc/cpu_ops.h | 11 ++++++- csrc/pythonInterface.cpp | 14 ++++++++- 3 files changed, 83 insertions(+), 8 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 5c2bc6332..0aadf596e 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -1,15 +1,26 @@ #include #include +#include #include using namespace BinSearch; -void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n) { - for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; - long long block_end = block_idx + valid_items; - for (long long i = block_idx; i < block_end; i++) - out[i] = code[A[i]] * absmax[block_idx / blocksize]; +template +void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { + switch (DATA_TYPE) { + case General8bit: + #pragma omp parallel for + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + for (long long i = block_idx; i < block_end; i++) + out[i] = static_cast(code[A[i]] * absmax[block_idx / blocksize]); + } + case NF4: + return; + case FP4: + return; + break; } } @@ -59,3 +70,46 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long threads[i].join(); } } + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); + +// template void gemv_4bit_inference( +// int m, int n, int k, at::Half* A, unsigned char* B, float* absmax, float* datatype, at::Half* out, +// int lda, int ldb, int ldc, int blocksize); + +// template void gemv_4bit_inference( +// int m, int n, int k, at::BFloat16* A, unsigned char* B, float* absmax, float* datatype, at::BFloat16* out, +// int lda, int ldb, int ldc, int blocksize); + +// 
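// The 8-bit path above is a plain codebook dequantization: each byte of A
// indexes the 256-entry `code` table and the result is scaled by the absmax
// recorded for its block (absmax[block_idx / blocksize]). A minimal
// standalone sketch of the same scheme, with illustrative names that are
// not part of this patch:
//
//   void dequant_block_sketch(const float* code, const unsigned char* q,
//                             float scale, float* out, long long n) {
//       for (long long i = 0; i < n; ++i)
//           out[i] = code[q[i]] * scale; // table lookup, then per-block rescale
//   }
//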
template void gemv_4bit_inference( +// int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, +// int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 3c10e6d13..72f759497 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -3,8 +3,17 @@ #include #include +#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); -void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +template +void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 28121240f..8bf32417f 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -845,6 +845,18 @@ void cquantize_blockwise_cpu_fp32( void cdequantize_blockwise_cpu_fp32( float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n ) { - dequantize_cpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_bf16( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_fp16( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } } From 252ac0f84af0f0c425f7b1075c2ee63b60b1ed6e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 28 Oct 2025 15:04:45 +0000 Subject: [PATCH 02/78] update cmake list Signed-off-by: jiqing-feng --- CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c133e09f..952be8a04 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -243,6 +243,7 @@ elseif(BUILD_XPU) set(CMAKE_CXX_COMPILER icx) endif() else() + find_package(Torch REQUIRED) string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) endif() @@ -317,7 +318,9 @@ if(BUILD_XPU) set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20) target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS}) target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS}) - +else() + target_link_options(bitsandbytes PRIVATE ${TORCH_LIBRARIES}) + include_directories(${TORCH_INCLUDE_DIRS}) endif() if(WIN32) From f98c9e5d98ffc2f6332665971bbd6cf08c8a3003 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 28 Oct 2025 15:14:48 +0000 Subject: [PATCH 03/78] fix typo Signed-off-by: jiqing-feng --- csrc/pythonInterface.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 8bf32417f..eaaa953f8 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -849,13 +849,13 @@ void cdequantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, float* 
absmax, at::Half* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } From 902bf359f51705ffa4d7f2ba527caaf5be77782f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 28 Oct 2025 18:17:23 +0000 Subject: [PATCH 04/78] fix compile cpu Signed-off-by: jiqing-feng --- CMakeLists.txt | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 952be8a04..808ade86f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,9 +78,16 @@ else() set(BUILD_HIP OFF) set(BUILD_MPS OFF) set(BUILD_XPU OFF) + set(BUILD_CPU ON) endif() +if (BUILD_CPU) + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + find_package(Torch REQUIRED) +endif() + if(BUILD_CUDA) # NVCC normally will only work with MSVC up to 1939. VS2022 17.10+ starts using versions 1940+. # Workaround: use --allow-unsupported-compiler @@ -242,10 +249,13 @@ elseif(BUILD_XPU) if(WIN32) set(CMAKE_CXX_COMPILER icx) endif() -else() +elseif(BUILD_CPU) find_package(Torch REQUIRED) string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) +else() + string(APPEND BNB_OUTPUT_NAME "_cpu") + set(GPU_SOURCES) endif() @@ -263,6 +273,9 @@ add_library(bitsandbytes SHARED ${SRC_FILES}) target_compile_features(bitsandbytes PUBLIC cxx_std_17) target_include_directories(bitsandbytes PUBLIC csrc include) +if (BUILD_CPU) + target_link_libraries(bitsandbytes "${TORCH_LIBRARIES}") +endif() if(BUILD_CUDA) target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) @@ -318,9 +331,7 @@ if(BUILD_XPU) set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20) target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS}) target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS}) -else() - target_link_options(bitsandbytes PRIVATE ${TORCH_LIBRARIES}) - include_directories(${TORCH_INCLUDE_DIRS}) + endif() if(WIN32) From fef8459f52c21b38f50ad7250f9ef9f4013b7e31 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 29 Oct 2025 09:36:04 +0000 Subject: [PATCH 05/78] make different dtype works Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index e295cc2a3..a69d89f3a 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -76,10 +76,8 @@ def _( torch._check_is_size(blocksize) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - # Only FP32 has c++ kernrl + out = torch.empty_like(A, dtype=dtype) if dtype == torch.float32: - out = torch.empty_like(A, dtype=dtype) - lib.cdequantize_blockwise_cpu_fp32( get_ptr(code), get_ptr(A), @@ -88,6 +86,24 @@ def _( ct.c_longlong(blocksize), ct.c_longlong(A.numel()), ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) else: out = code[A.reshape(-1).int()] blocks = out.shape[-1] // blocksize From 55cbaa0d0809711b1df1b5bb44da49312bc535d0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 29 Oct 2025 09:46:18 +0000 Subject: [PATCH 06/78] use bf16 on CPU Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py 
| 3 +++ 1 file changed, 3 insertions(+) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index ece18caa3..1cc24bb46 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -432,6 +432,9 @@ def matmul_4bit( bias: Optional[torch.Tensor] = None, ): assert quant_state is not None + # Change dtype to bfloat16 on CPU + if A.device.type == "cpu" and quant_state.dtype == torch.float32: + quant_state.dtype = torch.bfloat16 if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: From bbef95b3bab9168879e99a891ceecc924c140a14 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 29 Oct 2025 09:52:52 +0000 Subject: [PATCH 07/78] fix state2 dtype Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 1cc24bb46..0aba814c1 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -433,8 +433,11 @@ def matmul_4bit( ): assert quant_state is not None # Change dtype to bfloat16 on CPU - if A.device.type == "cpu" and quant_state.dtype == torch.float32: - quant_state.dtype = torch.bfloat16 + if A.device.type == "cpu": + if quant_state.dtype == torch.float32: + quant_state.dtype = torch.bfloat16 + if hasattr(quant_state, "state2") and quant_state.state2.dtype == torch.float32: + quant_state.state2.dtype = torch.bfloat16 if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: From e8425135de33f01d75cf1479ea14a387449fd268 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:27:34 +0000 Subject: [PATCH 08/78] remove torch Signed-off-by: jiqing-feng --- CMakeLists.txt | 8 ++---- csrc/cpu_ops.cpp | 63 ++++++++++++++++++++++++++++-------------------- csrc/cpu_ops.h | 10 ++++++++ 3 files changed, 49 insertions(+), 32 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 808ade86f..c5abfca78 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,7 +85,7 @@ endif() if (BUILD_CPU) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) - find_package(Torch REQUIRED) + find_package(OpenMP) endif() if(BUILD_CUDA) @@ -249,10 +249,6 @@ elseif(BUILD_XPU) if(WIN32) set(CMAKE_CXX_COMPILER icx) endif() -elseif(BUILD_CPU) - find_package(Torch REQUIRED) - string(APPEND BNB_OUTPUT_NAME "_cpu") - set(GPU_SOURCES) else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) @@ -274,7 +270,7 @@ target_compile_features(bitsandbytes PUBLIC cxx_std_17) target_include_directories(bitsandbytes PUBLIC csrc include) if (BUILD_CPU) - target_link_libraries(bitsandbytes "${TORCH_LIBRARIES}") + target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX) endif() if(BUILD_CUDA) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 0aadf596e..ec07a593b 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,25 +5,36 @@ using namespace BinSearch; + template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { - switch (DATA_TYPE) { - case General8bit: +void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, + long long blocksize, long long n) { + if (DATA_TYPE > 0) { #pragma omp parallel for for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = n - block_idx >= 
blocksize ? blocksize : n - block_idx; long long block_end = block_idx + valid_items; - for (long long i = block_idx; i < block_end; i++) - out[i] = static_cast(code[A[i]] * absmax[block_idx / blocksize]); + float scale = absmax[block_idx / blocksize]; + for (long long i = block_idx; i < block_end; i++) { + float v = code[A[i]] * scale; + if constexpr (std::is_same::value) { + out[i] = float_to_bf16(v); + } else { + out[i] = static_cast(v); + } + } } - case NF4: - return; - case FP4: - return; - break; + } else { + // 4bit path + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, n); } } +template +void dequantizeBlockwise4bitCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { + return; +} + void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below @@ -84,30 +95,30 @@ template void dequantizeBlockwiseCpu( template void dequantizeBlockwiseCpu( float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); -// template void gemv_4bit_inference( -// int m, int n, int k, at::Half* A, unsigned char* B, float* absmax, float* datatype, at::Half* out, +// template void gemv_4bit_inference( +// int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, // int lda, int ldb, int ldc, int blocksize); -// template void gemv_4bit_inference( -// int m, int n, int k, at::BFloat16* A, unsigned char* B, float* absmax, float* datatype, at::BFloat16* out, +// template void gemv_4bit_inference( +// int m, int n, int k, bf16_t* A, unsigned char* B, float* absmax, float* datatype, bf16_t* out, // int lda, int ldb, int ldc, int blocksize); // template void gemv_4bit_inference( diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h 
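The move from at::Half/at::BFloat16 to a native _Float16 alias and a raw uint16_t wrapper is what lets the CPU build stop linking libtorch: bfloat16 is simply the high 16 bits of an IEEE-754 float32, and the float_to_bf16 helper added two patches below rounds to nearest-even by adding 0x7FFF plus the tie-break bit ((bits >> 16) & 1) before shifting. The widening direction is exact and needs only a shift; a minimal sketch under the same bf16_t layout (bf16_to_float itself is not part of this series):

#include <cstdint>
#include <cstring>

struct bf16_t { uint16_t v; }; // raw bfloat16 bits, as in the patch

// Widening bf16 -> f32 is lossless: place the stored bits in the top half
// of a float32 and leave the low 16 mantissa bits zero.
static inline float bf16_to_float(bf16_t h) {
    uint32_t bits = static_cast<uint32_t>(h.v) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}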
index 72f759497..37026939a 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -4,6 +4,7 @@ #include #include #include +#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); @@ -13,7 +14,16 @@ typedef enum DataType_t { NF4 = 2, } DataType_t; +using fp16_t = _Float16; + +struct bf16_t { + uint16_t v; +}; + template <typename T, int DATA_TYPE> void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); +template <typename T, int DATA_TYPE> +void dequantizeBlockwise4bitCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) + #endif From d4473fa9314dfbb11135f7566b7a4a09acf21750 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:30:19 +0000 Subject: [PATCH 09/78] rm torch Signed-off-by: jiqing-feng --- csrc/cpu_ops.h | 2 -- csrc/pythonInterface.cpp | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 37026939a..19e4cf909 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -3,8 +3,6 @@ #include #include -#include -#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index eaaa953f8..5056ccf0c 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -849,14 +849,14 @@ void cdequantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n ) { - dequantizeBlockwiseCpu<at::BFloat16, General8bit>(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu<bf16_t, General8bit>(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n ) { - dequantizeBlockwiseCpu<at::Half, General8bit>(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu<fp16_t, General8bit>(code, A, absmax, out, blocksize, n); } } From dea8dd6377d788aaccabccc350d343c3a41d114f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:31:55 +0000 Subject: [PATCH 10/78] enable float to bf16 Signed-off-by: jiqing-feng --- csrc/cpu_ops.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 19e4cf909..85c13a334 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -3,6 +3,9 @@ #include #include +#include +#include +#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); @@ -18,6 +21,13 @@ struct bf16_t { uint16_t v; }; +static inline bf16_t float_to_bf16(float x) { + uint32_t bits; + std::memcpy(&bits, &x, 4); + uint32_t r = bits + 0x7FFF + ((bits >> 16) & 1); + return bf16_t{static_cast<uint16_t>(r >> 16)}; +} + template <typename T, int DATA_TYPE> void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); From e9bb4fe15ae0dbb73241d9e49a80ddacab1ecdd1 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:33:39 +0000 Subject: [PATCH 11/78] rm dequantizeBlockwise4bitCpu Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 9 ++------- csrc/cpu_ops.h | 3 --- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index ec07a593b..205b6824d 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -25,16 +25,11 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out } } } else { - // 4bit path - dequantizeBlockwise4bitCpu<T, DATA_TYPE>(code, A, absmax, out, blocksize, n); + // TODO: enable nf4 and fp4 + return; } } -template <typename T, int DATA_TYPE> -void dequantizeBlockwise4bitCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { - return; -} - void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 85c13a334..77791b3e6 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -31,7 +31,4 @@ static inline bf16_t float_to_bf16(float x) { template <typename T, int DATA_TYPE> void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); -template <typename T, int DATA_TYPE> -void dequantizeBlockwise4bitCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) - #endif From cdc8d5e02606bb740da660f309be00d543455eaa Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:43:03 +0000 Subject: [PATCH 12/78] fix check Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 205b6824d..dafc4c91f 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -9,7 +9,7 @@ using namespace BinSearch; template <typename T, int DATA_TYPE> void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { - if (DATA_TYPE > 0) { + if (DATA_TYPE == 0) { #pragma omp parallel for for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = n - block_idx >= blocksize ? 
blocksize : n - block_idx; From baacfac22604061533d500625c3698b85d8cf5b1 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 16:10:47 +0000 Subject: [PATCH 13/78] enable dequant 4bit kernel Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 106 +++++++++++++++++++++++++++++++ csrc/cpu_ops.cpp | 25 +++++++- csrc/cpu_ops.h | 88 +++++++++++++++++++++++++ csrc/pythonInterface.cpp | 34 ++++++++++ 4 files changed, 251 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index a69d89f3a..3c4399873 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -115,3 +115,109 @@ def _( out = out.reshape(A.shape) return out + +@register_kernel("bitsandbytes::dequantize_4bit", "cpu") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + # Enable non uint8 dtype + if A.dtype != torch.uint8: + A = A.view(torch.uint8) + + out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + if quant_type == "fp4": + if dtype == torch.float32: + lib.cdequantize_blockwise_cpu_fp4_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_fp4_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_fp4_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif quant_type == "nf4": + if dtype == torch.float32: + lib.cdequantize_blockwise_cpu_nf4_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_nf4_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_nf4_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + else: + A = A.reshape(-1) + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(dtype).to(A.device) + out_dq = code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + + out = out.reshape(-1, *shape[1:]).to(dtype) + + return out \ No newline at end of file diff 
--git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index dafc4c91f..83fa9db42 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -25,8 +25,29 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out } } } else { - // TODO: enable nf4 and fp4 - return; + #pragma omp parallel for + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + float scale = absmax[block_idx / blocksize]; + for (long long i = block_idx; i * 2 + 1 < block_end; i+=2) { + if (DATA_TYPE == 1) { + float up = dDequantizeFP4(A[i] >> 4) * scale; + float low = dDequantizeFP4(A[i] & 0x0F) * scale; + } elif (DATA_TYPE == 1) { + float up = dDequantizeNF4(A[i] >> 4) * scale; + float low = dDequantizeNF4(A[i] & 0x0F) * scale; + } + + if constexpr (std::is_same::value) { + out[i*2] = float_to_bf16(up); + out[i*2+1] = float_to_bf16(low); + } else { + out[i*2] = static_cast(up); + out[i*2+1] = static_cast(low); + } + } + } } } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 77791b3e6..9fd111719 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -28,6 +28,94 @@ static inline bf16_t float_to_bf16(float x) { return bf16_t{static_cast(r >> 16)}; } +inline float dDequantizeFP4(unsigned char val) { + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -0.25000000f; + else + return -0.16666667f; + else if ((val & 0b0001) == 1) + return -0.50000000f; + else + return -0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -1.00000000f; + else + return -0.66666667f; + else if ((val & 0b0001) == 1) + return -5.208333333e-03f; + else + return 0.00000000f; + else if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 0.25000000f; + else + return 0.16666667f; + else if ((val & 0b0001) == 1) + return 0.50000000f; + else + return 0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 1.00000000f; + else + return 0.66666667f; + else if ((val & 0b0001) == 1) + return 5.208333333e-03f; + else + return 0.00000000f; +} + +inline float dDequantizeNF4(unsigned char val) { + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) // 1 + if ((val & 0b0010) == 2) // 11 + if ((val & 0b0001) == 1) // 111 + return 1.0f; //*1111 + else + return 0.7229568362236023f; //*1110 + else if ((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; //*1101 + else + return 0.44070982933044434f; //*1100 + else if ((val & 0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; //*1011 + else + return 0.24611230194568634f; //*1010 + else if ((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; //*1001 + else + return 0.07958029955625534f; //*1000 + + else if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 011 + return 0.0f; //*0111 + else + return -0.09105003625154495f; //*0110 + else if ((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; //*0101 + else + return -0.28444138169288635f; //*0100 + else if ((val & 0b0010) == 2) // 00 + if ((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; //*0011 + else + return -0.5250730514526367f; //*0010 + else if ((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; //*0001 + else + return -1.0f; //*0000 +} + 
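The nested-if trees above decode a 4-bit code by testing its bits from MSB to LSB. Since the domain is only 16 values, an equivalent and easier-to-audit formulation is a flat table indexed by the nibble; the sketch below reuses the exact constants from dDequantizeNF4 and is not part of the patch (dDequantizeFP4 admits the same treatment, and such tables are also the natural starting point for a vectorized gather):

// Same mapping as dDequantizeNF4, expressed as a nibble-indexed table.
static const float kNF4Table[16] = {
    -1.0f,                 -0.6961928009986877f,  -0.5250730514526367f,  -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f,  0.0f,
     0.07958029955625534f,  0.16093020141124725f,  0.24611230194568634f,  0.33791524171829224f,
     0.44070982933044434f,  0.5626170039176941f,   0.7229568362236023f,   1.0f,
};

static inline float nf4_lookup(unsigned char nibble) { return kNF4Table[nibble & 0x0F]; }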
template void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 5056ccf0c..b69679cb7 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -859,4 +859,38 @@ void cdequantize_blockwise_cpu_fp16( ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } +void cdequantize_blockwise_cpu_fp4_fp32( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_fp4_bf16( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_fp4_fp16( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} +void cdequantize_blockwise_cpu_nf4_fp32( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_nf4_bf16( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_nf4_fp16( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} } From eec35212b86201b97c48af45dadca2514ceaeb49 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 16:11:53 +0000 Subject: [PATCH 14/78] fix typo Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 83fa9db42..2ccb4b3f9 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -34,7 +34,7 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out if (DATA_TYPE == 1) { float up = dDequantizeFP4(A[i] >> 4) * scale; float low = dDequantizeFP4(A[i] & 0x0F) * scale; - } elif (DATA_TYPE == 1) { + } else if (DATA_TYPE == 1) { float up = dDequantizeNF4(A[i] >> 4) * scale; float low = dDequantizeNF4(A[i] & 0x0F) * scale; } From d7cc1c5e6bf3f4afcfe9b615f7990e200640616c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 16:21:36 +0000 Subject: [PATCH 15/78] fix typo Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 13 +++++++------ csrc/cpu_ops.cpp | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 3c4399873..76ffec650 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence import ctypes as ct import logging @@ -139,7 +140,7 @@ def _( if quant_type == "fp4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_fp4_fp32( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -148,7 +149,7 @@ def _( ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_fp4_bf16( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -157,7 +158,7 @@ def _( ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp4_fp16( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -167,7 +168,7 @@ 
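Passing None instead of get_ptr(code) in these 4-bit wrappers does two things: it drops a reference to a name that was never defined in this function's scope, and it is safe because the FP4/NF4 kernels never dereference the table pointer — their codebooks are hardcoded in dDequantizeFP4/dDequantizeNF4, and the code argument survives only so the 8-bit and 4-bit entry points keep one C signature. A one-line statement of the invariant this relies on, as a sketch rather than library code:

// Only the General8bit path (DATA_TYPE == 0) reads the 256-entry code table,
// so the 4-bit entry points may be called with code == nullptr (None via ctypes).
static inline bool code_table_required(int data_type) { return data_type == 0; }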
def _( elif quant_type == "nf4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_nf4_fp32( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -176,7 +177,7 @@ def _( ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_nf4_bf16( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -185,7 +186,7 @@ def _( ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 2ccb4b3f9..ec7317b7d 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -31,10 +31,11 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; for (long long i = block_idx; i * 2 + 1 < block_end; i+=2) { + float up, low; if (DATA_TYPE == 1) { float up = dDequantizeFP4(A[i] >> 4) * scale; float low = dDequantizeFP4(A[i] & 0x0F) * scale; - } else if (DATA_TYPE == 1) { + } else { float up = dDequantizeNF4(A[i] >> 4) * scale; float low = dDequantizeNF4(A[i] & 0x0F) * scale; } From 124b754e85f6df425c29fdf00ce0f7e4f94452c4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 19:30:26 +0000 Subject: [PATCH 16/78] fix dequantize Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 45 +++++++++++++++++++++++++++----- csrc/cpu_ops.cpp | 22 ++++++++-------- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 76ffec650..d43d13dcd 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -145,7 +145,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_fp4_bf16( @@ -154,7 +154,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp4_fp16( @@ -163,7 +163,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif quant_type == "nf4": if dtype == torch.float32: @@ -173,7 +173,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_nf4_bf16( @@ -182,7 +182,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( @@ -191,7 +191,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) else: A = A.reshape(-1) @@ -221,4 +221,37 @@ def _( out = out.reshape(-1, *shape[1:]).to(dtype) + return out + +def dequant_nf4_x(A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype,): + out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + A = A.reshape(-1) + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(dtype).to(A.device) + out_dq = code[out_dq] + + # 
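# The block math below relies on n counting unpacked elements (two per input
# byte): each row of out_dq.view(-1, blocksize) is one quantization block, so
# multiplying by absmax.view(-1, 1) broadcasts a single scale across its
# block, and the has_rem branch handles a trailing partial block that cannot
# be viewed as a full row.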
Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) return out \ No newline at end of file diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index ec7317b7d..cff55a3bf 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -27,25 +27,25 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out } else { #pragma omp parallel for for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; - long long block_end = block_idx + valid_items; + long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); float scale = absmax[block_idx / blocksize]; - for (long long i = block_idx; i * 2 + 1 < block_end; i+=2) { + for (long long i = 0; i < valid_items; i+=2) { float up, low; + long long index = (i + block_idx) / 2; if (DATA_TYPE == 1) { - float up = dDequantizeFP4(A[i] >> 4) * scale; - float low = dDequantizeFP4(A[i] & 0x0F) * scale; + up = dDequantizeFP4(A[index] >> 4) * scale; + low = dDequantizeFP4(A[index] & 0x0F) * scale; } else { - float up = dDequantizeNF4(A[i] >> 4) * scale; - float low = dDequantizeNF4(A[i] & 0x0F) * scale; + up = dDequantizeNF4(A[index] >> 4) * scale; + low = dDequantizeNF4(A[index] & 0x0F) * scale; } if constexpr (std::is_same::value) { - out[i*2] = float_to_bf16(up); - out[i*2+1] = float_to_bf16(low); + out[i + block_idx] = float_to_bf16(up); + out[i+1 + block_idx] = float_to_bf16(low); } else { - out[i*2] = static_cast(up); - out[i*2+1] = static_cast(low); + out[i + block_idx] = static_cast(up); + out[i+1 + block_idx] = static_cast(low); } } } From 0f918c72cca2d60ee3e5f7272ccf74eb3fe31faf Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:12:53 +0000 Subject: [PATCH 17/78] fix Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 7 +++--- bitsandbytes/backends/cpu/ops.py | 37 ++--------------------------- csrc/cpu_ops.cpp | 37 +++++++++++++++-------------- csrc/cpu_ops.h | 2 +- csrc/pythonInterface.cpp | 18 +++++++------- 5 files changed, 34 insertions(+), 67 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 0aba814c1..158088c97 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -434,10 +434,9 @@ def matmul_4bit( assert quant_state is not None # Change dtype to bfloat16 on CPU if A.device.type == "cpu": - if quant_state.dtype == torch.float32: - quant_state.dtype = torch.bfloat16 - if hasattr(quant_state, "state2") and quant_state.state2.dtype == torch.float32: - quant_state.state2.dtype = torch.bfloat16 + quant_state.dtype = A.dtype + if hasattr(quant_state, "state2"): + quant_state.state2.dtype = A.dtype if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index d43d13dcd..e92c9b3f4 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -135,7 +135,8 @@ def _( # Enable non uint8 dtype if A.dtype != 
torch.uint8: A = A.view(torch.uint8) - + + A = A.reshape(-1) out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) if quant_type == "fp4": if dtype == torch.float32: @@ -194,7 +195,6 @@ def _( ct.c_longlong(out.numel()), ) else: - A = A.reshape(-1) # Map nf4 to [-1, 1] out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) n = out_dq.numel() @@ -222,36 +222,3 @@ def _( out = out.reshape(-1, *shape[1:]).to(dtype) return out - -def dequant_nf4_x(A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype,): - out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) - A = A.reshape(-1) - # Map nf4 to [-1, 1] - out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - n = out_dq.numel() - out_dq[1::2] = A & 0xF - out_dq[::2] = A >> 4 - # code is fp32, cast to dtype to avoid the mismatch issue - code = CODE[quant_type].to(dtype).to(A.device) - out_dq = code[out_dq] - - # Apply scales - if out_dq.numel() != n: - assert out_dq.numel() == n + 1 - out_dq = torch.narrow(out_dq, 0, 0, n) - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - rem = n % blocksize - has_rem = rem > 0 - - if has_rem: - out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) - out[n - rem :] = out_dq[n - rem :] * absmax[-1] - else: - out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) - return out \ No newline at end of file diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index cff55a3bf..f2c6f0690 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -7,15 +7,15 @@ using namespace BinSearch; template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, +void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n) { if (DATA_TYPE == 0) { #pragma omp parallel for for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; - for (long long i = block_idx; i < block_end; i++) { + for (long long i = block_idx; i < block_end; ++i) { float v = code[A[i]] * scale; if constexpr (std::is_same::value) { out[i] = float_to_bf16(v); @@ -29,23 +29,24 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); float scale = absmax[block_idx / blocksize]; - for (long long i = 0; i < valid_items; i+=2) { - float up, low; - long long index = (i + block_idx) / 2; - if (DATA_TYPE == 1) { - up = dDequantizeFP4(A[index] >> 4) * scale; - low = dDequantizeFP4(A[index] & 0x0F) * scale; - } else { - up = dDequantizeNF4(A[index] >> 4) * scale; - low = dDequantizeNF4(A[index] & 0x0F) * scale; - } - + for (long long i = 0; i < valid_items; i += 2) { + long long byte_index = (block_idx + i) >> 1; + unsigned char byte = A[byte_index]; + float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) + : dDequantizeNF4(byte & 0x0F)) * scale; + float v1 = (DATA_TYPE == 1 ? 
dDequantizeFP4(byte >> 4) + : dDequantizeNF4(byte >> 4)) * scale; if constexpr (std::is_same::value) { - out[i + block_idx] = float_to_bf16(up); - out[i+1 + block_idx] = float_to_bf16(low); + out[block_idx + i] = float_to_bf16(v0); } else { - out[i + block_idx] = static_cast(up); - out[i+1 + block_idx] = static_cast(low); + out[block_idx + i] = static_cast(v0); + } + if (i + 1 < valid_items) { + if constexpr (std::is_same::value) { + out[block_idx + i + 1] = float_to_bf16(v1); + } else { + out[block_idx + i + 1] = static_cast(v1); + } } } } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 9fd111719..3ad7b3ac2 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -117,6 +117,6 @@ inline float dDequantizeNF4(unsigned char val) { } template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); +void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index b69679cb7..33fcb6041 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -843,53 +843,53 @@ void cquantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_fp32( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp32( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_bf16( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp16( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_fp32( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_bf16( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { 
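// Each of these C-linkage wrappers pins one <T, DATA_TYPE> instantiation of
// dequantizeBlockwiseCpu to an unmangled symbol, since ctypes on the Python
// side can only resolve plain C names; apart from the pair they select, the
// wrapper bodies are identical one-line calls.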
dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_fp16( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } From e1a8b20d262eab013489a6f3b31b8e48a4a3a760 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:15:19 +0000 Subject: [PATCH 18/78] fix Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index f2c6f0690..46e238386 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -105,31 +105,31 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long //============================================================== template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); // template void gemv_4bit_inference( // int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, From eab45c8565f68224b9f896376a257d561b06f5fc Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:20:42 +0000 Subject: [PATCH 19/78] test Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/bitsandbytes/backends/cpu/ops.py 
b/bitsandbytes/backends/cpu/ops.py index e92c9b3f4..e49543867 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -185,6 +185,11 @@ def _( ct.c_longlong(blocksize), ct.c_longlong(out.numel()), ) + out_2 = dequantize_nf4_test(A, absmax, blocksize, quant_type, shape, dtype) + out = out.reshape(shape) + out_2 = out_2.reshape(shape) + if torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): + import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( None, @@ -222,3 +227,37 @@ def _( out = out.reshape(-1, *shape[1:]).to(dtype) return out + +def dequantize_nf4_test( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +): + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(dtype).to(A.device) + out_dq = code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + + return out From d9f5dd8e215c16e7b12d0bf994b8ba1530bd52b4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:22:49 +0000 Subject: [PATCH 20/78] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index e49543867..6b131be90 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -6,6 +6,7 @@ from bitsandbytes.functional import get_ptr +from ..util import CODE from ..._ops import register_kernel from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib From 070f8a082b623bdccd8755f8ea2152118223e8d8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:23:40 +0000 Subject: [PATCH 21/78] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 6b131be90..ecc744e2d 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -6,7 +6,7 @@ from bitsandbytes.functional import get_ptr -from ..util import CODE +from ..utils import CODE from ..._ops import register_kernel from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib From a84addfe5d6410f5752faba4f12f56c5b0b1e2ee Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:25:48 +0000 Subject: [PATCH 22/78] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index ecc744e2d..d6398b06b 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -189,7 +189,7 @@ def _( out_2 = dequantize_nf4_test(A, absmax, blocksize, quant_type, shape, dtype) out = out.reshape(shape) out_2 = out_2.reshape(shape) - if torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): + if not 
torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( From c4bb6607767668be20eb00d43aa91297055a4ca9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:33:13 +0000 Subject: [PATCH 23/78] fix Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 46e238386..03c0af795 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -32,10 +32,10 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, for (long long i = 0; i < valid_items; i += 2) { long long byte_index = (block_idx + i) >> 1; unsigned char byte = A[byte_index]; - float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) - : dDequantizeNF4(byte & 0x0F)) * scale; - float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) + float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) : dDequantizeNF4(byte >> 4)) * scale; + float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) + : dDequantizeNF4(byte & 0x0F)) * scale; if constexpr (std::is_same::value) { out[block_idx + i] = float_to_bf16(v0); } else { From 4ba13fd37f4d741648712abaab35edeab7039dd8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:40:55 +0000 Subject: [PATCH 24/78] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index d6398b06b..99bc21ca0 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -137,6 +137,10 @@ def _( if A.dtype != torch.uint8: A = A.view(torch.uint8) + # TODO: support half precision absmax + if absmax.dtype != torch.float32: + absmax = absmax.float() + A = A.reshape(-1) out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) if quant_type == "fp4": From c0d05ec1e03c24c8717de979fe49ca76d0d18733 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:22:02 +0000 Subject: [PATCH 25/78] change input param Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 2 - bitsandbytes/backends/cpu/ops.py | 39 +++++--- csrc/cpu_ops.cpp | 145 +++++++++++++++++++--------- csrc/cpu_ops.h | 2 +- csrc/pythonInterface.cpp | 18 ++-- 5 files changed, 137 insertions(+), 69 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 158088c97..061b4d1b8 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -435,8 +435,6 @@ def matmul_4bit( # Change dtype to bfloat16 on CPU if A.device.type == "cpu": quant_state.dtype = A.dtype - if hasattr(quant_state, "state2"): - quant_state.state2.dtype = A.dtype if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 99bc21ca0..c38c22583 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -86,7 +86,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_bf16( @@ -95,7 +96,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == 
torch.float16: lib.cdequantize_blockwise_cpu_fp16( @@ -104,7 +106,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) else: out = code[A.reshape(-1).int()] @@ -141,7 +144,7 @@ def _( if absmax.dtype != torch.float32: absmax = absmax.float() - A = A.reshape(-1) + A = A.reshape(shape[0], shape[1] // 2) out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) if quant_type == "fp4": if dtype == torch.float32: @@ -151,7 +154,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_fp4_bf16( @@ -160,7 +164,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp4_fp16( @@ -169,7 +174,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif quant_type == "nf4": if dtype == torch.float32: @@ -179,7 +185,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_nf4_bf16( @@ -188,7 +195,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) out_2 = dequantize_nf4_test(A, absmax, blocksize, quant_type, shape, dtype) out = out.reshape(shape) @@ -202,10 +210,12 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) else: # Map nf4 to [-1, 1] + A = A.reshape(-1) out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) n = out_dq.numel() out_dq[1::2] = A & 0xF @@ -229,7 +239,7 @@ def _( else: out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) - out = out.reshape(-1, *shape[1:]).to(dtype) + out = out.reshape(-1, *shape[1:]).to(dtype) return out @@ -266,3 +276,10 @@ def dequantize_nf4_test( out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) return out + + +def _reverse_4bit_compress_format(weight: torch.Tensor): + out_1 = (weight & 0xF0) >> 4 + out_2 = (weight & 0xF) << 4 + out = out_1 | out_2 + return out diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 03c0af795..ef5a27729 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -6,53 +6,97 @@ using namespace BinSearch; +// 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. +// DATA_TYPE: 1 = FP4, 2 = NF4 template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, - long long blocksize, long long n) { - if (DATA_TYPE == 0) { - #pragma omp parallel for - for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = (n - block_idx >= blocksize ? 
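// A quick worked illustration of the blockwise indexing in the loop above, with
// hypothetical sizes: for n = 10 and blocksize = 4 the blocks cover elements
// [0,3], [4,7] and [8,9]; the last block is short, which is why valid_items
// clamps to n - block_idx, and every element of a block shares one scale, e.g.
// out[9] = code[A[9]] * absmax[8 / 4] = code[A[9]] * absmax[2].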
blocksize : n - block_idx); - long long block_end = block_idx + valid_items; - float scale = absmax[block_idx / blocksize]; - for (long long i = block_idx; i < block_end; ++i) { - float v = code[A[i]] * scale; +inline void dequantizeBlockwise4bitCpu(float* code, + unsigned char* A, + const float* absmax, + T* out, + long long blocksize, + long long m, + long long n) { + static_assert(DATA_TYPE == 1 || DATA_TYPE == 2, + "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); + if (blocksize <= 0 || n <= 0) return; + +#if defined(__AVX512F__) && defined(__AVX512BW__) && defined(TEST_BUG) + // AVX512 optimized branch (placeholder) + // DATA_TYPE: 1 = FP4, 2 = NF4 + if (1 == 0) {return;} +#else + // Scalar fallback branch + long long total = m * n; + #pragma omp parallel for + for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { + long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); + float scale = absmax[block_idx / blocksize]; + for (long long i = 0; i < valid_items; i += 2) { + long long byte_index = (block_idx + i) >> 1; + unsigned char byte = A[byte_index]; + + // High nibble first (matches previous code logic) + float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) + : dDequantizeNF4(byte >> 4)) * scale; + // Low nibble second + float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) + : dDequantizeNF4(byte & 0x0F)) * scale; + + if constexpr (std::is_same::value) { + out[block_idx + i] = float_to_bf16(v0); + } else { + out[block_idx + i] = static_cast(v0); + } + + if (i + 1 < valid_items) { if constexpr (std::is_same::value) { - out[i] = float_to_bf16(v); + out[block_idx + i + 1] = float_to_bf16(v1); } else { - out[i] = static_cast(v); + out[block_idx + i + 1] = static_cast(v1); } } } - } else { + } +#endif +} + + +template +void dequantizeBlockwiseCpu(float* code, + unsigned char* A, + const float* absmax, + T* out, + long long blocksize, + long long m, + long long n) { + static_assert(DATA_TYPE == 0 || DATA_TYPE == 1 || DATA_TYPE == 2, + "dequantizeBlockwiseCpu: invalid DATA_TYPE"); + if (blocksize <= 0 || m <= 0 || n <= 0) return; + + if constexpr (DATA_TYPE == 0) { + // 8-bit path + long long total = (m * n) >> 1; #pragma omp parallel for - for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); + for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { + long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); + long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; - for (long long i = 0; i < valid_items; i += 2) { - long long byte_index = (block_idx + i) >> 1; - unsigned char byte = A[byte_index]; - float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) - : dDequantizeNF4(byte >> 4)) * scale; - float v1 = (DATA_TYPE == 1 ? 
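// Each byte of A packs two 4-bit codes, high nibble first: assuming a stored
// byte of 0xE1, v0 above decodes code 0xE and v1 decodes code 0x1, filling the
// adjacent outputs out[block_idx + i] and out[block_idx + i + 1]; the
// i + 1 < valid_items guard only matters for an odd-length tail block.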
dDequantizeFP4(byte & 0x0F) - : dDequantizeNF4(byte & 0x0F)) * scale; + for (long long i = block_idx; i < block_end; ++i) { + float v = code[A[i]] * scale; if constexpr (std::is_same::value) { - out[block_idx + i] = float_to_bf16(v0); + out[i] = float_to_bf16(v); } else { - out[block_idx + i] = static_cast(v0); - } - if (i + 1 < valid_items) { - if constexpr (std::is_same::value) { - out[block_idx + i + 1] = float_to_bf16(v1); - } else { - out[block_idx + i + 1] = static_cast(v1); - } + out[i] = static_cast(v); } } } + } else { + // 4-bit helper (FP4 / NF4) + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } } + void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below @@ -105,31 +149,40 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long //============================================================== template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); +template void dequantizeBlockwise4bitCpu( + float* code, 
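// These explicit instantiations are what keep the templated kernels usable from
// the plain-C wrappers in pythonInterface.cpp: every (T, DATA_TYPE) combination
// the wrappers call -- float, fp16_t and bf16_t against General8bit, FP4 and
// NF4 -- has to be spelled out in this translation unit, or the symbols would
// be missing at link time.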
unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); // template void gemv_4bit_inference( // int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 3ad7b3ac2..0ea071e2d 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -117,6 +117,6 @@ inline float dDequantizeNF4(unsigned char val) { } template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n); +void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 33fcb6041..2ab6920da 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -843,53 +843,53 @@ void cquantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_fp32( - float* code, unsigned char* A, const float* 
absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } From 62a16a6e8fb4611508d94c11fb4429a4322e8ea0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:27:19 +0000 Subject: [PATCH 26/78] fix typo Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 5 ++--- csrc/pythonInterface.cpp | 26 +++++++++++--------------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index ef5a27729..8ad2626d2 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -9,8 +9,7 @@ using namespace BinSearch; // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. // DATA_TYPE: 1 = FP4, 2 = NF4 template -inline void dequantizeBlockwise4bitCpu(float* code, - unsigned char* A, +inline void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, @@ -92,7 +91,7 @@ void dequantizeBlockwiseCpu(float* code, } } else { // 4-bit helper (FP4 / NF4) - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } } diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 2ab6920da..7cd74b844 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -845,52 +845,48 @@ void cquantize_blockwise_cpu_fp32( void cdequantize_blockwise_cpu_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_fp4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_fp4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void 
cdequantize_blockwise_cpu_fp4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_nf4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_nf4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_nf4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } } From d9ad828244b0a30b5640d42cb4a5e9076e17d8b6 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:30:41 +0000 Subject: [PATCH 27/78] fix input param Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 8ad2626d2..e7d80677f 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -169,19 +169,19 @@ template void dequantizeBlockwiseCpu( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); // template void gemv_4bit_inference( // int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, From 09ed6cbf455de0445cbf59bbb77100c7405c6d60 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:38:30 +0000 Subject: [PATCH 
28/78] split 8bit and 4bit Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 9 ++--- csrc/cpu_ops.cpp | 65 +++++++++++--------------------- csrc/cpu_ops.h | 8 ++-- csrc/pythonInterface.cpp | 18 ++++----- 4 files changed, 38 insertions(+), 62 deletions(-)
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index c38c22583..33060718f 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -86,8 +86,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), + ct.c_longlong(A.numel()), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_bf16( @@ -96,8 +95,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), + ct.c_longlong(A.numel()), ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp16( @@ -106,8 +104,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), + ct.c_longlong(A.numel()), ) else: out = code[A.reshape(-1).int()]
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e7d80677f..091925fca 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -9,13 +9,13 @@ using namespace BinSearch; // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. // DATA_TYPE: 1 = FP4, 2 = NF4 template <typename T, int DATA_TYPE> -inline void dequantizeBlockwise4bitCpu(unsigned char* A, +void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n) { - static_assert(DATA_TYPE == 1 || DATA_TYPE == 2, + static_assert(DATA_TYPE == 0 || DATA_TYPE == 1, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || n <= 0) return; @@ -60,38 +60,29 @@ inline void dequantizeBlockwise4bitCpu(unsigned char* A, } -template <typename T, int DATA_TYPE> -void dequantizeBlockwiseCpu(float* code, +template <typename T> +void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, - long long m, long long n) { - static_assert(DATA_TYPE == 0 || DATA_TYPE == 1 || DATA_TYPE == 2, - "dequantizeBlockwiseCpu: invalid DATA_TYPE"); - if (blocksize <= 0 || m <= 0 || n <= 0) return; - - if constexpr (DATA_TYPE == 0) { - // 8-bit path - long long total = (m * n) >> 1; - #pragma omp parallel for - for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { - long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); - long long block_end = block_idx + valid_items; - float scale = absmax[block_idx / blocksize]; - for (long long i = block_idx; i < block_end; ++i) { - float v = code[A[i]] * scale; - if constexpr (std::is_same<T, bf16_t>::value) { - out[i] = float_to_bf16(v); - } else { - out[i] = static_cast<T>(v); + if (blocksize <= 0 || n <= 0) return; + // 8-bit path + long long total = (m * n) >> 1; + #pragma omp parallel for + for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { + long long valid_items = (total - block_idx >= blocksize ? 
blocksize : total - block_idx); + long long block_end = block_idx + valid_items; + float scale = absmax[block_idx / blocksize]; + for (long long i = block_idx; i < block_end; ++i) { + float v = code[A[i]] * scale; + if constexpr (std::is_same::value) { + out[i] = float_to_bf16(v); + } else { + out[i] = static_cast(v); } } - } else { - // 4-bit helper (FP4 / NF4) - dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } } @@ -147,25 +138,11 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long // TEMPLATE DEFINITIONS //============================================================== -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( +template void dequantizeBlockwise8bitCpu( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); - -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( +template void dequantizeBlockwise8bitCpu( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); - -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( +template void dequantizeBlockwise8bitCpu( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 0ea071e2d..092261c4f 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -10,9 +10,8 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); typedef enum DataType_t { - General8bit = 0, + NF4 = 0, FP4 = 1, - NF4 = 2, } DataType_t; using fp16_t = _Float16; @@ -116,7 +115,10 @@ inline float dDequantizeNF4(unsigned char val) { return -1.0f; //*0000 } +template +void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); + template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); +void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 7cd74b844..127d147a0 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -845,48 +845,48 @@ void cquantize_blockwise_cpu_fp32( void cdequantize_blockwise_cpu_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu(code, A, absmax, out, 
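// Note the asymmetry this split preserves: code is the 256-entry lookup table
// produced at quantization time, so the 8-bit wrappers keep forwarding it,
// while the 4-bit wrappers can eventually drop it because dDequantizeFP4 and
// dDequantizeNF4 hard-code their 16-entry tables.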
blocksize, m, n); } void cdequantize_blockwise_cpu_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } } From a3f7b61128bf051e824b9979dbb3f166dfdced56 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:39:36 +0000 Subject: [PATCH 29/78] fix typo Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 6 +++--- csrc/cpu_ops.h | 2 +- csrc/pythonInterface.cpp | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 091925fca..2bc380a91 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -139,11 +139,11 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long //============================================================== template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); template void dequantizeBlockwise8bitCpu( - float* code, unsigned 
char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); template void dequantizeBlockwise4bitCpu<float, FP4>( unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n);
diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 092261c4f..047466f7a 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -116,7 +116,7 @@ inline float dDequantizeNF4(unsigned char val) { } template <typename T> -void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); +void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n); template <typename T, int DATA_TYPE> void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n);
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 127d147a0..894432ede 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -845,17 +845,17 @@ void cquantize_blockwise_cpu_fp32( void cdequantize_blockwise_cpu_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise8bitCpu<float>(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu<float>(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise8bitCpu<bf16_t>(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu<bf16_t>(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise8bitCpu<fp16_t>(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu<fp16_t>(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp32(
From 47084701d5a736435e9ceaf2f74033b5229cdf2c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:40:36 +0000 Subject: [PATCH 30/78] fix typo Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 2bc380a91..55d129ceb 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -69,10 +69,9 @@ void dequantizeBlockwise8bitCpu(float* code, long long n) { if (blocksize <= 0 || n <= 0) return; // 8-bit path - long long total = (m * n) >> 1; #pragma omp parallel for - for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { - long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = (n - block_idx >= blocksize ? 
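// Dropping total = (m * n) >> 1 is the substance of this fix: halving the
// element count is only correct for 4-bit data, where one byte packs two
// values; the 8-bit path stores one value per byte and is called with the flat
// length n, so the rewritten loop above has to walk all n elements.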
blocksize : n - block_idx); long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; for (long long i = block_idx; i < block_end; ++i) { From 1dfe9f71648079fa33532c0c72e8eb4766cf896c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:42:12 +0000 Subject: [PATCH 31/78] fix input params Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 33060718f..c4475eef1 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -146,7 +146,6 @@ def _( if quant_type == "fp4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_fp4_fp32( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -156,7 +155,6 @@ def _( ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_fp4_bf16( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -166,7 +164,6 @@ def _( ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp4_fp16( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -177,7 +174,6 @@ def _( elif quant_type == "nf4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_nf4_fp32( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -187,7 +183,6 @@ def _( ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_nf4_bf16( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -202,7 +197,6 @@ def _( import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), From 00289c429dc28a552894bf086e0b1349d07af74f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:43:43 +0000 Subject: [PATCH 32/78] fix input params Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 10 +++++----- csrc/pythonInterface.cpp | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 55d129ceb..e9c477893 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -10,11 +10,11 @@ using namespace BinSearch; // DATA_TYPE: 1 = FP4, 2 = NF4 template void dequantizeBlockwise4bitCpu(unsigned char* A, - const float* absmax, - T* out, - long long blocksize, - long long m, - long long n) { + const float* absmax, + T* out, + long long blocksize, + long long m, + long long n) { static_assert(DATA_TYPE == 0 || DATA_TYPE == 1, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || n <= 0) return; diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 894432ede..62d3bf826 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -861,32 +861,32 @@ void cdequantize_blockwise_cpu_fp16( void cdequantize_blockwise_cpu_fp4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + 
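// With the code argument gone from the call, each 4-bit wrapper is a thin shim
// over one template instantiation; once the unused code parameter is also
// dropped from the wrapper signatures (as a later patch in this series does),
// the nf4/bf16 entry point reduces to the following sketch, the template
// arguments being implied by the wrapper name:
//
//   void cdequantize_blockwise_cpu_nf4_bf16(
//       unsigned char* A, const float* absmax, bf16_t* out,
//       long long blocksize, long long m, long long n) {
//       dequantizeBlockwise4bitCpu<bf16_t, NF4>(A, absmax, out, blocksize, m, n);
//   }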
dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } } From a2578baabaaaf178f331e5357fd2f3ac3ed6654b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 13:04:05 +0000 Subject: [PATCH 33/78] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 4 +++- csrc/cpu_ops.cpp | 4 ++-- csrc/pythonInterface.cpp | 12 ++++++------ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index c4475eef1..acf4caa34 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -190,7 +190,7 @@ def _( ct.c_longlong(shape[0]), ct.c_longlong(shape[1]), ) - out_2 = dequantize_nf4_test(A, absmax, blocksize, quant_type, shape, dtype) + out_2 = dequantize_nf4_test(A.reshape(-1), absmax, blocksize, quant_type, shape, dtype) out = out.reshape(shape) out_2 = out_2.reshape(shape) if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): @@ -266,6 +266,8 @@ def dequantize_nf4_test( else: out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + out = out.reshape(-1, *shape[1:]).to(dtype) + return out diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e9c477893..a0c9bd50c 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -7,7 +7,7 @@ using namespace BinSearch; // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. 
-// DATA_TYPE: 1 = FP4, 2 = NF4 +// DATA_TYPE: 1 = FP4, 0 = NF4 template void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, @@ -17,7 +17,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, long long n) { static_assert(DATA_TYPE == 0 || DATA_TYPE == 1, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); - if (blocksize <= 0 || n <= 0) return; + if (blocksize <= 0 || m < 0 || n <= 0) return; #if defined(__AVX512F__) && defined(__AVX512BW__) && defined(TEST_BUG) // AVX512 optimized branch (placeholder) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 62d3bf826..fd89a626e 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -859,33 +859,33 @@ void cdequantize_blockwise_cpu_fp16( } void cdequantize_blockwise_cpu_fp4_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } From 72033dc1a39f803c40a42b0f31f0ae6bddc7b2af Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 13:19:58 +0000 Subject: [PATCH 34/78] fix typo Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 2 +- csrc/pythonInterface.cpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index acf4caa34..2a6014940 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -142,7 +142,7 @@ def _( absmax = absmax.float() A = A.reshape(shape[0], shape[1] // 2) - out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + out = torch.empty(shape, dtype=dtype, device=A.device) if quant_type == "fp4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_fp4_fp32( diff --git a/csrc/pythonInterface.cpp 
b/csrc/pythonInterface.cpp index fd89a626e..d9914951f 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -843,17 +843,17 @@ void cquantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } From 1c20ae831e4371f62e668ff9df0216e492edb48b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 15:34:41 +0000 Subject: [PATCH 35/78] enable dequant4bit Signed-off-by: jiqing-feng --- CMakeLists.txt | 12 ++ bitsandbytes/backends/cpu/ops.py | 212 ++++++++++++++----------------- csrc/cpu_ops.cpp | 127 +++++++++++++++++- csrc/cpu_ops.h | 1 + 4 files changed, 229 insertions(+), 123 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c5abfca78..8d4a492c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -271,6 +271,18 @@ target_include_directories(bitsandbytes PUBLIC csrc include) if (BUILD_CPU) target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX) + include(CheckCXXCompilerFlag) + + check_cxx_compiler_flag(-mavx512f HAS_AVX512F) + check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16) + + if(HAS_AVX512F) + target_compile_options(bitsandbytes PRIVATE -mavx512f) + endif() + + if(HAS_AVX512BF16) + target_compile_options(bitsandbytes PRIVATE -mavx512bf16) + endif() endif() if(BUILD_CUDA) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 2a6014940..57cd830c2 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -27,6 +27,12 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0]) +def _reverse_4bit_compress_format(weight: torch.Tensor): + out_1 = (weight & 0xF0) >> 4 + out_2 = (weight & 0xF) << 4 + out = out_1 | out_2 + return out + if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): @register_kernel("bitsandbytes::quantize_blockwise", "cpu") @@ -118,121 +124,95 @@ def _( return out -@register_kernel("bitsandbytes::dequantize_4bit", "cpu") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - torch._check_is_size(blocksize) - torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - # Enable non uint8 dtype - if A.dtype != torch.uint8: - A = A.view(torch.uint8) - - # TODO: support half precision absmax - if absmax.dtype != torch.float32: - absmax = absmax.float() - - A = A.reshape(shape[0], shape[1] // 2) - out = torch.empty(shape, 
dtype=dtype, device=A.device) - if quant_type == "fp4": - if dtype == torch.float32: - lib.cdequantize_blockwise_cpu_fp4_fp32( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_cpu_fp4_bf16( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - elif dtype == torch.float16: - lib.cdequantize_blockwise_cpu_fp4_fp16( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - elif quant_type == "nf4": - if dtype == torch.float32: - lib.cdequantize_blockwise_cpu_nf4_fp32( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_cpu_nf4_bf16( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - out_2 = dequantize_nf4_test(A.reshape(-1), absmax, blocksize, quant_type, shape, dtype) - out = out.reshape(shape) - out_2 = out_2.reshape(shape) - if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): - import pdb; pdb.set_trace() - elif dtype == torch.float16: - lib.cdequantize_blockwise_cpu_nf4_fp16( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - else: - # Map nf4 to [-1, 1] - A = A.reshape(-1) - out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - n = out_dq.numel() - out_dq[1::2] = A & 0xF - out_dq[::2] = A >> 4 - # code is fp32, cast to dtype to avoid the mismatch issue - code = CODE[quant_type].to(dtype).to(A.device) - out_dq = code[out_dq] - - # Apply scales - if out_dq.numel() != n: - assert out_dq.numel() == n + 1 - out_dq = torch.narrow(out_dq, 0, 0, n) - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - rem = n % blocksize - has_rem = rem > 0 - - if has_rem: - out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) - out[n - rem :] = out_dq[n - rem :] * absmax[-1] + @register_kernel("bitsandbytes::dequantize_4bit", "cpu") + def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + ) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + # Enable non uint8 dtype + if A.dtype != torch.uint8: + A = A.view(torch.uint8) + + # TODO: support half precision absmax + if absmax.dtype != torch.float32: + absmax = absmax.float() + + A = _reverse_4bit_compress_format(A) + A = A.reshape(shape[0], shape[1] // 2) + out = torch.empty(shape, dtype=dtype, device=A.device) + if quant_type == "fp4": + if dtype == torch.float32: + lib.cdequantize_blockwise_cpu_fp4_fp32( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_fp4_bf16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + 
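# shape[0] x shape[1] is the 2-D view of the packed weight handed to the C++
# kernel: m rows of n 4-bit elements, stored as n // 2 bytes per row, with one
# absmax scale for every blocksize elements of a row.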
ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_fp4_fp16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + elif quant_type == "nf4": + if dtype == torch.float32: + lib.cdequantize_blockwise_cpu_nf4_fp32( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_nf4_bf16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + out_2 = dequantize_nf4_test(_reverse_4bit_compress_format(A.reshape(-1)), absmax, blocksize, quant_type, shape, dtype) + if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): + import pdb; pdb.set_trace() + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_nf4_fp16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) else: - out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) - - out = out.reshape(-1, *shape[1:]).to(dtype) + raise ValueError - return out + return out def dequantize_nf4_test( A: torch.Tensor, @@ -270,9 +250,3 @@ def dequantize_nf4_test( return out - -def _reverse_4bit_compress_format(weight: torch.Tensor): - out_1 = (weight & 0xF0) >> 4 - out_2 = (weight & 0xF) << 4 - out = out_1 | out_2 - return out diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index a0c9bd50c..5beeaf3d4 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,6 +5,76 @@ using namespace BinSearch; +// #if defined(__AVX512F__) +#if 1 +#include + +inline __m256i cvt_fp32_to_fp16(const __m512 src) { + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + +inline __m256i cvt_fp32_to_bf16(const __m512 src) { + #if defined(__AVX512BF16__) + return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src)); + #else + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); + #endif +} + +static inline __m512 set_nf4_lut() { + return _mm512_set_ps( + 1.0f, + 0.7229568362236023, + 0.5626170039176941, + 0.44070982933044434, + 0.33791524171829224, + 0.24611230194568634, + 0.16093020141124725, + 0.07958029955625534, + 0.0f, + -0.09105003625154495, + -0.18477343022823334, + -0.28444138169288635, + -0.39491748809814453, + -0.5250730514526367, + -0.6961928009986877, + -1.0f); +} +static inline __m512 set_fp4_lut() { + return _mm512_set_ps( + 0.0000f, + 5.208333333e-03f, + 0.66666667f, + 1.0000f, + 0.33333333f, + 0.5000f, + 0.16666667f, + 0.2500f, + 0.0000f, + -5.208333333e-03f, + -0.66666667f, + -1.0000f, + -0.33333333f, + -0.5000f, + -0.16666667f, + -0.2500f); +} +#endif // 4-bit (FP4 / NF4) dequantization helper extracted from the 
original else branch. // DATA_TYPE: 1 = FP4, 0 = NF4 @@ -19,10 +89,59 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || m < 0 || n <= 0) return; -#if defined(__AVX512F__) && defined(__AVX512BW__) && defined(TEST_BUG) - // AVX512 optimized branch (placeholder) - // DATA_TYPE: 1 = FP4, 2 = NF4 - if (1 == 0) {return;} +// #if defined(__AVX512F__) && defined(TEST_BUG) +# if 1 + auto dim_0 = m; + auto dim_1 = n; + auto input_dim_1 = dim_1 >> 1; + using Tcomp = float; + constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16 + if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) { + __m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut(); + constexpr auto k_step = VEC_LEN / 2; // 8 + // auto dequant_loop = ThreadedLoop<2>({{dim_0}}, /* loop_scheme */ "A"); + // dequant_loop( + // [&](int* idx) { + // int block_idx = idx[0]; + #pragma omp parallel for + for (int block_idx = 0; block_idx < dim_0; ++block_idx) { + for (int k = 0; k < input_dim_1; k += k_step) { + // Load 64 bits of nf4 data and a single scale data + // auto p = A[block_idx * input_dim_1 + k]; + uint8_t* p = &A[block_idx * input_dim_1 + k]; + uint64_t packed; + std::memcpy(&packed, p, sizeof(uint64_t)); + auto scale_idx = k * 2 / blocksize; + auto vscales = _mm512_set1_ps((float)absmax[block_idx * blocksize + scale_idx]); + // uint64_t packed = reinterpret_cast<uint64_t*>(p)[0]; + // unpack nf4 data to 32-bit integers + uint64_t high = 0; + uint64_t low = 0; + for (int i = 0; i < 8; ++i) { + low |= ((packed >> (i * 4)) & 0xf) << (i * 8); + high |= ((packed >> (i * 4 + 32)) & 0xf) << (i * 8); + } + __m128i packed_128 = _mm_set_epi64x(high, low); + __m512i vint32 = _mm512_cvtepu8_epi32(packed_128); + // Table look-up + __m512 vout = _mm512_permutexvar_ps(vint32, lut); + // Apply scale + vout = _mm512_mul_ps(vout, vscales); + // Store results + // auto pout = out[block_idx * dim_1 + k * 2]; + T* pout = &out[block_idx * dim_1 + k * 2]; // out[block_idx][k/k_step] + if constexpr (std::is_same<T, float>()) { + _mm512_storeu_ps(pout, vout); + } else if constexpr (std::is_same<T, bf16_t>()) { + _mm256_storeu_si256( + (__m256i*)pout, cvt_fp32_to_bf16(vout)); + } else if constexpr (std::is_same<T, fp16_t>()) { + _mm256_storeu_si256( + (__m256i*)pout, cvt_fp32_to_fp16(vout)); + } + } + } + } #else // Scalar fallback branch long long total = m * n;
diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 047466f7a..6be5a864c 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -115,6 +115,7 @@ inline float dDequantizeNF4(unsigned char val) { return -1.0f; //*0000 } + template <typename T> void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n);
From 7552fe22e94ef1cdf472df84587c048615fad3d0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 16:05:42 +0000 Subject: [PATCH 36/78] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 6 +++--- csrc/cpu_ops.cpp | 24 ++++++++---------------- 2 files changed, 11 insertions(+), 19 deletions(-)
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 57cd830c2..c5a45c914 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -197,9 +197,9 @@ def _( ct.c_longlong(shape[0]), ct.c_longlong(shape[1]), ) - out_2 = dequantize_nf4_test(_reverse_4bit_compress_format(A.reshape(-1)), absmax, blocksize, quant_type, shape, dtype) - if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): - import pdb; pdb.set_trace() + # 
out_2 = dequantize_nf4_test(_reverse_4bit_compress_format(A.reshape(-1)), absmax, blocksize, quant_type, shape, dtype) + # if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): + # import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( get_ptr(A), diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 5beeaf3d4..f9c7a5364 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,8 +5,7 @@ using namespace BinSearch; -// #if defined(__AVX512F__) -#if 1 +#if defined(__AVX512F__) #include inline __m256i cvt_fp32_to_fp16(const __m512 src) { @@ -89,31 +88,25 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || m < 0 || n <= 0) return; -// #if defined(__AVX512F__) && defined(TEST_BUG) -# if 1 - auto dim_0 = m; - auto dim_1 = n; - auto input_dim_1 = dim_1 >> 1; +#if defined(__AVX512F__) && defined(TEST_BUG) + long long dim_0 = m; + long long dim_1 = n; + long long input_dim_1 = dim_1 >> 1; + long long absmax_dim_1 = dim_1 / blocksize using Tcomp = float; constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16 if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) { __m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut(); constexpr auto k_step = VEC_LEN / 2; // 8 - // auto dequant_loop = ThreadedLoop<2>({{dim_0}}, /* loop_scheme */ "A"); - // dequant_loop( - // [&](int* idx) { - // int block_idx = idx[0]; #pragma omp parallel for for (int block_idx = 0; block_idx < dim_0; ++block_idx) { for (int k = 0; k < input_dim_1; k += k_step) { // Load 64 bits of nf4 data and a single scale data - // auto p = A[block_idx * input_dim_1 + k]; uint8_t* p = &A[block_idx * input_dim_1 + k]; uint64_t packed; std::memcpy(&packed, p, sizeof(uint64_t)); auto scale_idx = k * 2 / blocksize; - auto vscales = _mm512_set1_ps((float)absmax[block_idx * blocksize + scale_idx]); - // uint64_t packed = reinterpret_cast(p)[0]; + auto vscales = _mm512_set1_ps((float)absmax[block_idx * absmax_dim_1 + scale_idx]); // unpack nf4 data to 32-bit integers uint64_t high = 0; uint64_t low = 0; @@ -128,8 +121,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, // Apply scale vout = _mm512_mul_ps(vout, vscales); // Store results - // auto pout = out[block_idx * dim_1 + k * 2]; - T* pout = &out[block_idx * dim_1 + k * 2]; // out[block_idx][k/k_step] + T* pout = &out[block_idx * dim_1 + k * 2]; if constexpr (std::is_same()) { _mm512_storeu_ps(pout, vout); } else if constexpr (std::is_same()) { From 8b32a39c34464c44120128d7d0982371087f9d70 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 16:09:33 +0000 Subject: [PATCH 37/78] fix Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index f9c7a5364..ddfd64aa7 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -88,11 +88,11 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || m < 0 || n <= 0) return; -#if defined(__AVX512F__) && defined(TEST_BUG) +#if defined(__AVX512F__) long long dim_0 = m; long long dim_1 = n; long long input_dim_1 = dim_1 >> 1; - long long absmax_dim_1 = dim_1 / blocksize + long long absmax_dim_1 = dim_1 / blocksize; using Tcomp = float; constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16 if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) { From 8f1cc3699be96062564b8296aa2382c3356e71d3 Mon Sep 17 00:00:00 2001 From: 
jiqing-feng Date: Fri, 31 Oct 2025 17:47:23 +0000 Subject: [PATCH 38/78] fix reverse Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 7 ------- csrc/cpu_ops.cpp | 10 +++++++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index c5a45c914..a716c7580 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -27,12 +27,6 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0]) -def _reverse_4bit_compress_format(weight: torch.Tensor): - out_1 = (weight & 0xF0) >> 4 - out_2 = (weight & 0xF) << 4 - out = out_1 | out_2 - return out - if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): @register_kernel("bitsandbytes::quantize_blockwise", "cpu") @@ -147,7 +141,6 @@ def _( if absmax.dtype != torch.float32: absmax = absmax.float() - A = _reverse_4bit_compress_format(A) A = A.reshape(shape[0], shape[1] // 2) out = torch.empty(shape, dtype=dtype, device=A.device) if quant_type == "fp4": diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index ddfd64aa7..a46799c58 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,6 +5,8 @@ using namespace BinSearch; +#define __AVX512F__ + #if defined(__AVX512F__) #include @@ -110,9 +112,11 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, // unpack nf4 data to 32-bit integers uint64_t high = 0; uint64_t low = 0; - for (int i = 0; i < 8; ++i) { - low |= ((packed >> (i * 4)) & 0xf) << (i * 8); - high |= ((packed >> (i * 4 + 32)) & 0xf) << (i * 8); + for (int i = 0; i < 4; ++i) { + low |= ((packed >> (2*i * 4)) & 0xf) << ((2*i+1) * 8); + low |= ((packed >> ((2*i+1) * 4)) & 0xf) << (2*i * 8); + high |= ((packed >> (2*i * 4 + 32)) & 0xf) << ((2*i+1) * 8); + high |= ((packed >> ((2*i+1) * 4 + 32)) & 0xf) << (2*i * 8); } __m128i packed_128 = _mm_set_epi64x(high, low); __m512i vint32 = _mm512_cvtepu8_epi32(packed_128); From 49d242a82751c45bb3ad04aae6eb740d62eecc40 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 3 Nov 2025 09:15:24 +0000 Subject: [PATCH 39/78] fix dequant 4bit fallback path Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index a46799c58..a797b0ab2 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,7 +5,6 @@ using namespace BinSearch; -#define __AVX512F__ #if defined(__AVX512F__) #include @@ -137,8 +136,9 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, } } } + return; } -#else +#endif // Scalar fallback branch long long total = m * n; #pragma omp parallel for @@ -171,7 +171,6 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, } } } -#endif } From 4a9a6dc1817bc38110ac94656bd606024a4b953b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 3 Nov 2025 09:42:55 +0000 Subject: [PATCH 40/78] fix fp4 dequant Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index a797b0ab2..f8082fb7a 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -57,22 +57,22 @@ static inline __m512 set_nf4_lut() { } static inline __m512 set_fp4_lut() { return _mm512_set_ps( + -0.2500f, + -0.16666667f, + -0.5000f, + -0.33333333f, + -1.0000f, + -0.66666667f, + -5.208333333e-03f, 0.0000f, - 5.208333333e-03f, - 0.66666667f, - 1.0000f, - 0.33333333f, - 0.5000f, - 0.16666667f, 0.2500f, - 0.0000f, - -5.208333333e-03f, - -0.66666667f, - -1.0000f, - -0.33333333f, - -0.5000f, 
- -0.16666667f, - -0.2500f); + 0.16666667f, + 0.5000f, + 0.33333333f, + 1.0000f, + 0.66666667f, + 5.208333333e-03f, + 0.0000f); } #endif From d7e981d920c8ac79c208a3fb56f8493203cdd386 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 5 Nov 2025 12:52:55 +0000 Subject: [PATCH 41/78] rm _Float16 Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 6 ++++++ csrc/cpu_ops.h | 49 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index f8082fb7a..f590bc6ab 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -158,6 +158,8 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, if constexpr (std::is_same::value) { out[block_idx + i] = float_to_bf16(v0); + } else if constexpr (std::is_same::value) { + out[block_idx + i] = float_to_fp16(v0); } else { out[block_idx + i] = static_cast(v0); } @@ -165,6 +167,8 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, if (i + 1 < valid_items) { if constexpr (std::is_same::value) { out[block_idx + i + 1] = float_to_bf16(v1); + } else if constexpr (std::is_same::value) { + out[block_idx + i + 1] = float_to_fp16(v1); } else { out[block_idx + i + 1] = static_cast(v1); } @@ -192,6 +196,8 @@ void dequantizeBlockwise8bitCpu(float* code, float v = code[A[i]] * scale; if constexpr (std::is_same::value) { out[i] = float_to_bf16(v); + } else if constexpr (std::is_same::value) { + out[i] = float_to_fp16(v); } else { out[i] = static_cast(v); } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 6be5a864c..fea894d79 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -1,11 +1,8 @@ #ifndef BITSANDBYTES_CPU_OPS_H #define BITSANDBYTES_CPU_OPS_H -#include -#include #include #include -#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); @@ -14,7 +11,9 @@ typedef enum DataType_t { FP4 = 1, } DataType_t; -using fp16_t = _Float16; +struct fp16_t { + uint16_t v; +}; struct bf16_t { uint16_t v; @@ -27,6 +26,48 @@ static inline bf16_t float_to_bf16(float x) { return bf16_t{static_cast(r >> 16)}; } +static inline fp16_t float_to_fp16(float x) { + uint32_t bits; + std::memcpy(&bits, &x, 4); + uint32_t sign = (bits >> 31) & 0x1; + uint32_t exp = (bits >> 23) & 0xFF; + uint32_t mant = bits & 0x7FFFFF; + + uint16_t h; + if (exp == 0xFF) { // Inf / NaN + uint16_t mant16 = mant ? 
0x200 : 0; // quiet NaN: set MSB of mantissa + h = (sign << 15) | (0x1F << 10) | mant16; + } else if (exp > 0x70 + 0x1E) { // overflow: exp_f -127 +15 > 30 (exp_f > 142) + h = (sign << 15) | (0x1F << 10); // Inf + } else if (exp < 0x71) { // subnormal or zero (exp_f < 113) + if (exp < 0x67) { // too small -> zero (exp_f < 103) + h = (sign << 15); + } else { + // subnormal: implicit leading 1 + uint32_t shift = 0x71 - exp; + uint32_t mant_with_hidden = mant | 0x800000; + // add rounding bias before shifting (23-10 =13 bits to drop + shift) + uint32_t rounded = (mant_with_hidden + (1u << (shift + 12))) >> (shift + 13); + h = (sign << 15) | (uint16_t)rounded; + } + } else { + // normalized + uint32_t exp_h = exp - 127 + 15; + // round mantissa: add 2^(23-10-1) = 0x1000 + uint32_t mant_rounded = mant + 0x00001000; + if (mant_rounded & 0x00800000) { // mantissa overflow after rounding + mant_rounded = 0; + ++exp_h; + if (exp_h >= 0x1F) { // overflow to Inf + h = (sign << 15) | (0x1F << 10); + return fp16_t{h}; + } + } + h = (sign << 15) | ((uint16_t)exp_h << 10) | ((uint16_t)(mant_rounded >> 13)); + } + return fp16_t{h}; +} + inline float dDequantizeFP4(unsigned char val) { if ((val & 0b1000) == 8) if ((val & 0b0100) == 4) From 48739b0945e5fa877f6293588c60bb3d7ce778e5 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 6 Nov 2025 09:20:02 +0000 Subject: [PATCH 42/78] tmp codes Signed-off-by: jiqing-feng --- bitsandbytes/backends/utils.py | 33 +++++++++++++++++++++++++++++++++ csrc/cpu_ops.cpp | 21 ++++++++++++--------- csrc/cpu_ops.h | 3 +++ csrc/pythonInterface.cpp | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py index ec96a440c..0d7b5cb4d 100644 --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -81,4 +81,37 @@ def get_gaudi_sw_version(): return version.parse(output.stdout.split("\n")[0].split()[-1]) +def convert_weight_packed_for_cpu(qweight: torch.Tensor, + scales: torch.Tensor, + block_n: int = 32): + """ + qweight: (K * N / 2) uint8 + return: packed_weight + """ + assert qweight.dtype == torch.uint8, "qweight must be uint8" + qweight = qweight.reshape(-1) + unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=A.device) + unpacked_w[1::2] = qweight & 0xF + unpacked_w[::2] = qweight >> 4 + qweight_final = unpacked_w.reshape(shape).transpose(-1, -2).to(torch.uint8) # (*, N, K) + # pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit + assert len(qweight_final.shape) == 2 + N, K = qweight_final.shape[0], qweight_final.shape[1] + assert N % block_n == 0, "N must be divisible by block_n" + assert K % 2 == 0, "K must be even" + BLOCK_N = block_n + BIT_COUNT = 32 # (=32 low +32 high) + prefix = sizes[:-2] + new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2] + out_shape = [N, K // 2] + qw = qweight_final.reshape(new_shape) # (..., N/B, B, K/2, 2) + qw = qw.transpose(-3, -2).contiguous() # (..., N/B, K/2, B, 2) + qw = qw.reshape(-1, BIT_COUNT * 2) # [-1, 64] + high = qw[:, BIT_COUNT:] # high 32 + low = qw[:, :BIT_COUNT] # low 32 + packed = ((high << 4) | low).to(torch.uint8) # combine + final_qweight = packed.reshape(out_shape) + return final_qweight + + GAUDI_SW_VER = get_gaudi_sw_version() diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index f590bc6ab..8124ef303 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -279,14 +279,17 @@ template void dequantizeBlockwise4bitCpu( template void dequantizeBlockwise4bitCpu( unsigned 
char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); -// template void gemv_4bit_inference( -// int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, -// int lda, int ldb, int ldc, int blocksize); +template void gemv_4bit_inference( + long long m, long long n, long long k, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride); +template void gemv_4bit_inference( + long long m, long long n, long long k, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride); -// template void gemv_4bit_inference( -// int m, int n, int k, bf16_t* A, unsigned char* B, float* absmax, float* datatype, bf16_t* out, -// int lda, int ldb, int ldc, int blocksize); +template void gemv_4bit_inference( + long long m, long long n, long long k, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride); +template void gemv_4bit_inference( + long long m, long long n, long long k, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride); -// template void gemv_4bit_inference( -// int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, -// int lda, int ldb, int ldc, int blocksize); +template void gemv_4bit_inference( + long long m, long long n, long long k, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride); +template void gemv_4bit_inference( + long long m, long long n, long long k, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride); diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index fea894d79..cb09f6a81 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -163,4 +163,7 @@ void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absm template void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); +template +void gemv_4bit_inference(long long m, long long n, long long k, T* x, unsigned char* w, const float* absmax, T* out, long long blocksize, long long x_stride, long long out_stride); + #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index d9914951f..ae8ecd98d 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -889,4 +889,36 @@ void cdequantize_blockwise_cpu_nf4_fp16( ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + +void gemv_4bit_inference_cpu_fp4_fp32( + long long m, long long n, long long k, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride +) { + gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +} +void gemv_4bit_inference_cpu_fp4_fp16( + long long m, long long n, long long k, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride +) { + gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +} +void gemv_4bit_inference_cpu_fp4_bf16( + long long m, long long n, long long k, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride +) { + gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +} + +void 
gemv_4bit_inference_cpu_nf4_fp32( + long long m, long long n, long long k, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride +) { + gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +} +void gemv_4bit_inference_cpu_nf4_fp16( + long long m, long long n, long long k, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride +) { + gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +} +void gemv_4bit_inference_cpu_nf4_bf16( + long long m, long long n, long long k, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride +) { + gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +} } From f784be866415ae226629238b2263656740bcb476 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 08:54:04 +0000 Subject: [PATCH 43/78] enable gemv Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 44 ++++ bitsandbytes/backends/utils.py | 34 +-- bitsandbytes/nn/modules.py | 11 +- bitsandbytes/utils.py | 33 +++ csrc/cpu_ops.cpp | 382 ++++++++++++++++++++++++++++++- csrc/cpu_ops.h | 76 +++++- csrc/pythonInterface.cpp | 44 ++-- 7 files changed, 555 insertions(+), 69 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index a716c7580..a32081a3d 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -207,6 +207,50 @@ def _( return out + if hasattr(lib, "gemv_4bit_inference_cpu_nf4_bf16"): + @register_kernel("bitsandbytes::gemv_4bit", "cpu") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + ) -> torch.Tensor: + # Applied from dequantize_4bit + dtype = A.dtype + quant_type = "fp4" if code[1] > 0 else "nf4" + # cpu fused op only support bf16 for now. 
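+            # Assumed precision contract (sketch, not a documented guarantee):
+            # the AVX512 kernel accumulates with bf16 dot products
+            # (_mm512_dpbf16_ps in tinygemm_kernel_nn), so fp16/fp32
+            # activations round-trip through bf16, i.e.
+            #   A.to(torch.bfloat16) -> fused gemv -> out.to(dtype)
+            # with the cast back to the caller's dtype at the end of this
+            # function.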
+ if dtype != torch.bfloat16: + A = A.to(torch.bfloat16) + + out_shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(out_shape, dtype=A.dtype, device=A.device) + if quant_type == "fp4": + lib.cdequantize_blockwise_cpu_fp4_bf16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + elif quant_type == "nf4": + lib.cdequantize_blockwise_cpu_nf4_bf16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + + if dtype != torch.bfloat16: + out = out.to(dtype) + + return out + + def dequantize_nf4_test( A: torch.Tensor, absmax: torch.Tensor, diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py index 0d7b5cb4d..1d6abd5c1 100644 --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -1,6 +1,7 @@ import subprocess from packaging import version +from collections.abc import Sequence import torch try: @@ -81,37 +82,4 @@ def get_gaudi_sw_version(): return version.parse(output.stdout.split("\n")[0].split()[-1]) -def convert_weight_packed_for_cpu(qweight: torch.Tensor, - scales: torch.Tensor, - block_n: int = 32): - """ - qweight: (K * N / 2) uint8 - return: packed_weight - """ - assert qweight.dtype == torch.uint8, "qweight must be uint8" - qweight = qweight.reshape(-1) - unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=A.device) - unpacked_w[1::2] = qweight & 0xF - unpacked_w[::2] = qweight >> 4 - qweight_final = unpacked_w.reshape(shape).transpose(-1, -2).to(torch.uint8) # (*, N, K) - # pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit - assert len(qweight_final.shape) == 2 - N, K = qweight_final.shape[0], qweight_final.shape[1] - assert N % block_n == 0, "N must be divisible by block_n" - assert K % 2 == 0, "K must be even" - BLOCK_N = block_n - BIT_COUNT = 32 # (=32 low +32 high) - prefix = sizes[:-2] - new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2] - out_shape = [N, K // 2] - qw = qweight_final.reshape(new_shape) # (..., N/B, B, K/2, 2) - qw = qw.transpose(-3, -2).contiguous() # (..., N/B, K/2, B, 2) - qw = qw.reshape(-1, BIT_COUNT * 2) # [-1, 64] - high = qw[:, BIT_COUNT:] # high 32 - low = qw[:, :BIT_COUNT] # low 32 - packed = ((high << 4) | low).to(torch.uint8) # combine - final_qweight = packed.reshape(out_shape) - return final_qweight - - GAUDI_SW_VER = get_gaudi_sw_version() diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 79f8daf2a..76d41b081 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -14,7 +14,8 @@ from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer +from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer, convert_weight_packed_for_cpu +from ..cextension import ErrorHandlerMockBNBNativeLibrary, lib T = TypeVar("T", bound="torch.nn.Module") @@ -479,6 +480,7 @@ def __init__( self.compute_type_is_set = compute_dtype is not None self.quant_state = None self.quant_storage = quant_storage + self.enable_optimized_cpu = False def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -512,8 +514,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): destination[prefix + "weight." 
+ k] = v if keep_vars else v.detach()
 
     def forward(self, x: torch.Tensor):
         fix_4bit_weight_quant_state_from_module(self)
+        quant_state = self.weight.quant_state
 
+        if not self.enable_optimized_cpu and not isinstance(lib, ErrorHandlerMockBNBNativeLibrary) and hasattr(lib, "gemv_4bit_inference_cpu_nf4_bf16"):
+            self.weight.data, quant_state.absmax = convert_weight_packed_for_cpu(self.weight.data, quant_state.absmax, quant_state.shape)
+            self.enable_optimized_cpu = True
+
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
@@ -529,7 +536,7 @@ def forward(self, x: torch.Tensor):
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
         weight = self.weight.t()
 
-        return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype)
 
 
 class LinearFP4(Linear4bit):
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index 98ccd7da6..2d76fe4ca 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -203,3 +203,36 @@ def sync_gpu(t: torch.Tensor):
         torch.cuda.synchronize()
     elif t.device.type == "xpu":
         torch.xpu.synchronize()
+
+
+def convert_weight_packed_for_cpu(qweight: torch.Tensor,
+                                  scales: torch.Tensor,
+                                  shape: "Sequence[int]",
+                                  block_n: int = 32):
+    """
+    qweight: flat (N * K / 2,) uint8 tensor of packed 4-bit values
+    return: (packed_weight, transposed scales)
+    """
+    assert qweight.dtype == torch.uint8, "qweight must be uint8"
+    qweight = qweight.reshape(-1)
+    unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=qweight.device)
+    # the high nibble of each byte holds the earlier element, the low nibble the later one
+    unpacked_w[1::2] = qweight & 0xF
+    unpacked_w[::2] = qweight >> 4
+    qweight_final = unpacked_w.reshape(shape).to(torch.uint8)  # (N, K)
+    # pack weight: [N, K] -> [N, K/2], recombining low and high bits blockwise
+    assert len(qweight_final.shape) == 2
+    N, K = qweight_final.shape[0], qweight_final.shape[1]
+    assert N % block_n == 0, "N must be divisible by block_n"
+    assert K % 2 == 0, "K must be even"
+    BLOCK_N = block_n
+    BIT_COUNT = 32  # (= 32 low + 32 high)
+    new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2]
+    out_shape = [N, K // 2]
+    qw = qweight_final.reshape(new_shape)  # (N/B, B, K/2, 2)
+    qw = qw.transpose(-3, -2).contiguous()  # (N/B, K/2, B, 2)
+    qw = qw.reshape(-1, BIT_COUNT * 2)  # [-1, 64]
+    high = qw[:, BIT_COUNT:]  # high 32
+    low = qw[:, :BIT_COUNT]  # low 32
+    packed = ((high << 4) | low).to(torch.uint8)  # combine
+    final_qweight = packed.reshape(out_shape)
+    return final_qweight, scales.T
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index 8124ef303..1282f55fa 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -253,6 +253,368 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long
     }
 }
 
+
+#if true or defined(__AVX512F__) && defined(__AVX512BF16__)
+
+#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
+
+// template <bool sym_quant>
+// struct load_dequant_zp_only_4bit<32, sym_quant> {
+// #if defined(CPU_CAPABILITY_AVX512)
+//   static inline std::array<__m512, 2> call(
+//       uint8_t* p,
+//       __m512 lut,
+//       std::array<__m512, 2> vzps) {
+//     using T = float;
+//     using VA = VecArray<32, T>;
+//     using VAT = typename VA::type;
+//     constexpr long COLS = VA::num_vec;
+//     auto packed = _mm_loadu_si128((__m128i*)p);
+//     __m512i int32[COLS];
+//     {
+//       auto low_4bit = _mm512_cvtepu8_epi32(packed);
+//       auto high_4bit = _mm512_srli_epi32(low_4bit, 4);
+// 
int32[0] = low_4bit; +// int32[1] = high_4bit; +// } +// VAT vbs; +// compile_time_for::op([&](auto idx) { +// vbs[idx] = _mm512_permutexvar_ps(int32[idx], lut); +// if constexpr (!sym_quant) { +// vbs[idx] = _mm512_sub_ps(vbs[idx], vzps[idx]); +// } +// }); +// return vbs; +// } +// #endif + +template +struct tinygemm_kernel_nn { + static inline void apply( + const bf16_t* __restrict__ A, + const unsigned char* __restrict__ B, + bf16_t* __restrict__ C, + bf16_t* __restrict__ Bs, + int64_t K, + int group_size, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t strideBz, + int64_t strideBs) { + static_assert(BLOCK_N % 32 == 0); + constexpr int ROWS = BLOCK_M; // 32 + constexpr int COLS = BLOCK_N / 16; // 2 + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 16 * 4; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vc_master[ROWS * COLS]; + + __m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit + // w and z are in [0,15], hence (w-z) is in [-15,15] + // we will add 15 to it to shift it to [0,30] for lookup table indexing + __m256i fifteen = _mm256_set1_epi8(15); + __m512i bf16_lut = _mm512_set_epi16( + 0x0000, + 0x4170, + 0x4160, + 0x4150, + 0x4140, + 0x4130, + 0x4120, + 0x4110, + 0x4100, + 0x40E0, + 0x40C0, + 0x40A0, + 0x4080, + 0x4040, + 0x4000, + 0x3F80, + 0x0000, + -0x4080, + -0x4000, + -0x3FC0, + -0x3F80, + -0x3F60, + -0x3F40, + -0x3F20, + -0x3F00, + -0x3EF0, + -0x3EE0, + -0x3ED0, + -0x3EC0, + -0x3EB0, + -0x3EA0, + -0x3E90); + __m512 scales[COLS]; + // repeat interleave + __m256i idx1 = _mm256_set_epi8( + 31, + 31, + 30, + 30, + 29, + 29, + 28, + 28, + 27, + 27, + 26, + 26, + 25, + 25, + 24, + 24, + 23, + 23, + 22, + 22, + 21, + 21, + 20, + 20, + 19, + 19, + 18, + 18, + 17, + 17, + 16, + 16); + __m256i idx0 = _mm256_set_epi8( + 15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const int64_t gs2 = group_size >> 1; // 64 / 2 = 32 + const float* a_ptr = reinterpret_cast(A); + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + vc_master[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + // x * ((w - zeros) * scales) + // = (x * (w - zeros)) * scales + + auto pre_compute = [&](auto i, int64_t kgs) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = _mm512_set1_ps(0.f); // reset accumulator + + // load zeros and scales + if constexpr (row == 0 && col % 2 == 0) { + // Bz layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=uint8 + __m256i tmp = _mm256_loadu_si256(reinterpret_cast(Bz + kgs * strideBz + col * 16)); + // (w - (z - 15)) = (w - z + 15) + tmp = _mm256_sub_epi8(tmp, fifteen); + zeros[col] = _mm256_permutexvar_epi8(idx0, tmp); + zeros[col + 1] = _mm256_permutexvar_epi8(idx1, tmp); + + // Bs layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=bf16 + __m512i tmp2 = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + col * 16)); + scales[col] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp2, 0)); + scales[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp2, 1)); + } + }; + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0 && col % 2 == 0) { + __m256i vb_u4 = _mm256_loadu_si256(reinterpret_cast(B + k * ldb + col * 16)); + + // deinterleave and lookup to BF16 + __m256i vb_i8_lo = 
vb_u4 & mask; + __m256i vb_i8_hi = _mm256_srli_epi16(vb_u4, 4) & mask; + vb_i8_lo = _mm256_sub_epi8(vb_i8_lo, zeros[col]); + vb_i8_hi = _mm256_sub_epi8(vb_i8_hi, zeros[col + 1]); + vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), bf16_lut); + vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), bf16_lut); + + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + auto post_compute = [&](auto i, int64_t kgs) { + vc_master[i] = _mm512_fmadd_ps(vc[i], scales[i % COLS], vc_master[i]); + }; + for (int64_t k = 0; k < K2; k += gs2) { + Unroll{}(pre_compute, k / gs2); + for (int64_t k_offset = 0; k_offset < gs2; ++k_offset) { + Unroll{}(compute, k + k_offset); + } + Unroll{}(post_compute, k / gs2); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>(C + row * ldc + col * 16), + (__m512i)(_mm512_cvtne2ps_pbh(vc_master[i + 1], vc_master[i]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, \ + B + nb_start, \ + C + mb_start * ldc + nb_start, \ + Bs + nb_start, \ + K, \ + group_size, \ + lda, \ + ldb, \ + ldc, \ + strideBz, \ + strideBs); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const unsigned char* __restrict__ B, + scalar_t* __restrict__ C, + const scalar_t* __restrict__ Bs, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int group_size, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t strideBz, + int64_t strideBs) { + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NN(1, 32); + break; + case 0x14: + LAUNCH_TINYGEMM_KERNEL_NN(1, 64); + break; + // mb_size = 2 + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NN(2, 32); + break; + case 0x24: + LAUNCH_TINYGEMM_KERNEL_NN(2, 64); + break; + // mb_size = 3 + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NN(3, 32); + break; + case 0x34: + LAUNCH_TINYGEMM_KERNEL_NN(3, 64); + break; + // mb_size = 4 + case 0x42: + LAUNCH_TINYGEMM_KERNEL_NN(4, 32); + break; + case 0x44: + LAUNCH_TINYGEMM_KERNEL_NN(4, 64); + break; + default: { + std::fprintf(stderr, + "[bitsandbytes] Unexpected block size %lldx%lld\n", + (long long)mb_size, + (long long)nb_size); + std::abort(); // or return; if you prefer silent exit + } + } + } + } +} + +template +void gemv_4bit_inference(long long M, + long long N, + long long K, + T* x, + unsigned char* w, + const float* absmax, + T* out, + long long blocksize, + long long x_stride, + long long out_stride) { + constexpr int64_t BLOCK_M = block_size_m(); // 32 + constexpr int64_t BLOCK_N = block_size_n(); // 32 + const int64_t MB = div_up(M, BLOCK_M); // (x + y -1)/ y, res = 1 when M <= 32 + const int64_t NB = div_up(N, BLOCK_N); + // TODO: enable brgemm in the future. 
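+    // Sketch of the intended brgemm path (an assumption based on the
+    // commented-out at::native::cpublas::brgemm_release() call below, not
+    // an implemented code path): for larger M the per-tile bf16 dot-product
+    // kernel would be swapped for a batch-reduce GEMM over dequantized
+    // weight tiles, roughly
+    //   dequant(w_tile -> Btmp); brgemm(x_tile, Btmp, Ctmp); scale + store;
+    // until then every shape takes the tinygemm_kernel path, so the flags
+    // stay disabled: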
+ // const bool use_brgemm = M > 4; + // const bool use_brgemm_dequant_out = M > 512; + scalar_t* Btmp_start = nullptr; + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N * K); + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + alignas(64) scalar_t Btmp_inner[BLOCK_N * BLOCK_K]; // BLOCK_K = 128 + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { // 0-1 + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { + int64_t mb_start = mb * BLOCK_M; // 0 + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + tinygemm_kernel( + /* A */ x + mb_start * mat1_strideM, + /* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2 + /* C */ out + mb_start * out_strideM + nb_start, + /* Bs */ absmax + nb_start, + /* Btmp */ Btmp_inner, + /* Ctmp */ Ctmp, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* gs */ group_size, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* sBz */ N, + /* sBs */ N); + } + } + } + // if (use_brgemm) { + // at::native::cpublas::brgemm_release(); + // } + }); +} +#endif + + //============================================================== // TEMPLATE DEFINITIONS //============================================================== @@ -279,17 +641,17 @@ template void dequantizeBlockwise4bitCpu( template void dequantizeBlockwise4bitCpu( unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); -template void gemv_4bit_inference( - long long m, long long n, long long k, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride); -template void gemv_4bit_inference( - long long m, long long n, long long k, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride); +// template void gemv_4bit_inference( +// long long M, long long N, long long K, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride); +// template void gemv_4bit_inference( +// long long M, long long N, long long K, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride); -template void gemv_4bit_inference( - long long m, long long n, long long k, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride); -template void gemv_4bit_inference( - long long m, long long n, long long k, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride); +// template void gemv_4bit_inference( +// long long M, long long N, long long K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride); +// template void gemv_4bit_inference( +// long long M, long long N, long long K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride); template void gemv_4bit_inference( - long long m, long long n, long long k, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, 
long long blocksize, long long x_stride, long long out_stride); + long long M, long long N, long long K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride); template void gemv_4bit_inference( - long long m, long long n, long long k, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride); + long long M, long long N, long long K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride); diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index cb09f6a81..ae4e98af5 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -4,6 +4,76 @@ #include #include +// amx-bf16 +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 + +// block size for AMX gemm +constexpr int block_size_m() { + return 2 * TILE_M; +} +constexpr int block_size_n() { + return 2 * TILE_N; +} + +template +inline int get_cache_blocks(int chunk_size) { + // L2 2MB and ratio of 50% + const int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (chunk_size * sizeof(T)))); +} + +template +inline void parallel_2d(int m, int n, const func_t& f) { + // make sure we have even num_threads + int nth = adjust_num_threads(m); + + // [NOTE] thread blocking: + // + // 1) prefer square block per thread + // 2) use even number of CPU cores + // 3) use all `num_threads` cores + // + // we have: + // TM * TN = T + // BM / TM = BN / TN + // then: + // TM = ((BM / BN) * T) ^ 0.5 + // + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) { + break; + } + } + +#if defined(_OPENMP) +#pragma omp parallel num_threads(nth) + { + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; + + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); + + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); + + f(begin_m, end_m, begin_n, end_n); + } +#else + f(0, m, 0, n); +#endif +} + + void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); typedef enum DataType_t { @@ -163,7 +233,9 @@ void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absm template void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); -template -void gemv_4bit_inference(long long m, long long n, long long k, T* x, unsigned char* w, const float* absmax, T* out, long long blocksize, long long x_stride, long long out_stride); +#if defined(__AVX512F__) && defined(__AVX512BF16__) + template + void gemv_4bit_inference(long long M, long long N, long long K, T* x, unsigned char* w, const float* absmax, T* out, long long blocksize, long long x_stride, long long out_stride); +#endif #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index ae8ecd98d..74479376f 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -890,34 +890,34 @@ void cdequantize_blockwise_cpu_nf4_fp16( dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } -void gemv_4bit_inference_cpu_fp4_fp32( - long long m, long long n, long long k, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long 
x_stride, long long out_stride -) { - gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); -} -void gemv_4bit_inference_cpu_fp4_fp16( - long long m, long long n, long long k, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride -) { - gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); -} +// void gemv_4bit_inference_cpu_fp4_fp32( +// long long M, long long N, long long K, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride +// ) { +// gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +// } +// void gemv_4bit_inference_cpu_fp4_fp16( +// long long M, long long N, long long K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride +// ) { +// gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +// } void gemv_4bit_inference_cpu_fp4_bf16( - long long m, long long n, long long k, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride + long long M, long long N, long long K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride ) { gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); } -void gemv_4bit_inference_cpu_nf4_fp32( - long long m, long long n, long long k, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride -) { - gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); -} -void gemv_4bit_inference_cpu_nf4_fp16( - long long m, long long n, long long k, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride -) { - gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); -} +// void gemv_4bit_inference_cpu_nf4_fp32( +// long long M, long long N, long long K, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride +// ) { +// gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +// } +// void gemv_4bit_inference_cpu_nf4_fp16( +// long long M, long long N, long long K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride +// ) { +// gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +// } void gemv_4bit_inference_cpu_nf4_bf16( - long long m, long long n, long long k, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride + long long M, long long N, long long K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride ) { gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); } From 92192c9f4b91fe62e032cf0eb9c6538499acd640 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 10:24:44 +0000 Subject: [PATCH 44/78] change to 4bit dequant Signed-off-by: jiqing-feng --- bitsandbytes/utils.py | 2 +- csrc/cpu_ops.cpp | 173 +++++++++--------------------------------- 2 files changed, 37 insertions(+), 138 deletions(-) diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 2d76fe4ca..0d0fb6be9 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -235,4 +235,4 @@ def convert_weight_packed_for_cpu(qweight: torch.Tensor, low = qw[:, :BIT_COUNT] # 
low 32 packed = ((high << 4) | low).to(torch.uint8) # combine final_qweight = packed.reshape(out_shape) - return final_qweight, scales.T + return final_qweight, scales.T.to(torch.bfloat16) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 1282f55fa..97402f290 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -254,42 +254,12 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long } -#if true or defined(__AVX512F__) && defined(__AVX512BF16__) +#if true // || defined(__AVX512F__) && defined(__AVX512BF16__) #define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) -// template -// struct load_dequant_zp_only_4bit<32, sym_quant> { -// #if defined(CPU_CAPABILITY_AVX512) -// static inline std::array<__m512, 2> call( -// uint8_t* p, -// __m512 lut, -// std::array<__m512, 2> vzps) { -// using T = float; -// using VA = VecArray<32, T>; -// using VAT = typename VA::type; -// constexpr long COLS = VA::num_vec; -// auto packed = _mm_loadu_si128((__m128i*)p); -// __m512i int32[COLS]; -// { -// auto low_4bit = _mm512_cvtepu8_epi32(packed); -// auto high_4bit = _mm512_srli_epi32(low_4bit, 4); -// int32[0] = low_4bit; -// int32[1] = high_4bit; -// } -// VAT vbs; -// compile_time_for::op([&](auto idx) { -// vbs[idx] = _mm512_permutexvar_ps(int32[idx], lut); -// if constexpr (!sym_quant) { -// vbs[idx] = _mm512_sub_ps(vbs[idx], vzps[idx]); -// } -// }); -// return vbs; -// } -// #endif - -template -struct tinygemm_kernel_nn { +template +struct tinygemm_kernel_nn { static inline void apply( const bf16_t* __restrict__ A, const unsigned char* __restrict__ B, @@ -315,80 +285,22 @@ struct tinygemm_kernel_nn { __m512 vc_master[ROWS * COLS]; __m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit - // w and z are in [0,15], hence (w-z) is in [-15,15] - // we will add 15 to it to shift it to [0,30] for lookup table indexing - __m256i fifteen = _mm256_set1_epi8(15); - __m512i bf16_lut = _mm512_set_epi16( - 0x0000, - 0x4170, - 0x4160, - 0x4150, - 0x4140, - 0x4130, - 0x4120, - 0x4110, - 0x4100, - 0x40E0, - 0x40C0, - 0x40A0, - 0x4080, - 0x4040, - 0x4000, - 0x3F80, - 0x0000, - -0x4080, - -0x4000, - -0x3FC0, - -0x3F80, - -0x3F60, - -0x3F40, - -0x3F20, - -0x3F00, - -0x3EF0, - -0x3EE0, - -0x3ED0, - -0x3EC0, - -0x3EB0, - -0x3EA0, - -0x3E90); + __m512i lut = DATA_TYPE == 1 ? 
_mm512_set_epi16( + /* e31..e16 copy e15..e0 */ + 0x0000, 0x3A45, 0x3F30, 0x3EAA, 0x3F80, 0x3F2A, 0x3B4F, 0x3E80, + 0x0000, 0xBA45, 0xBF30, 0xBEAA, 0xBF80, 0xBF2A, 0xBB4F, 0xBE80, + /* e15..e0 original index 0..15 */ + 0x0000, 0x3A45, 0x3F30, 0x3EAA, 0x3F80, 0x3F2A, 0x3B4F, 0x3E80, + 0x0000, 0xBA45, 0xBF30, 0xBEAA, 0xBF80, 0xBF2A, 0xBB4F, 0xBE80 + ) : _mm512_set_epi16( + /* e31..e16 copy e15..e0 */ + 0xBF80, 0xBFA5, 0xBF0C, 0xBECA, 0xBE84, 0xBE1C, 0xBDA4, 0xBD28, + 0x0000, 0x3D2E, 0x3D5F, 0x3DAE, 0x3DF0, 0x3E0C, 0x3E38, 0x3F80, + /* e15..e0 original index 0..15 */ + 0xBF80, 0xBFA5, 0xBF0C, 0xBECA, 0xBE84, 0xBE1C, 0xBDA4, 0xBD28, + 0x0000, 0x3D2E, 0x3D5F, 0x3DAE, 0x3DF0, 0x3E0C, 0x3E38, 0x3F80 + ); __m512 scales[COLS]; - // repeat interleave - __m256i idx1 = _mm256_set_epi8( - 31, - 31, - 30, - 30, - 29, - 29, - 28, - 28, - 27, - 27, - 26, - 26, - 25, - 25, - 24, - 24, - 23, - 23, - 22, - 22, - 21, - 21, - 20, - 20, - 19, - 19, - 18, - 18, - 17, - 17, - 16, - 16); - __m256i idx0 = _mm256_set_epi8( - 15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); - const int64_t K2 = K >> 1; const int64_t lda2 = lda >> 1; const int64_t ldb2 = ldb; // ldb * 2 >> 1; @@ -401,27 +313,17 @@ struct tinygemm_kernel_nn { }; Unroll{}(loadc); - // x * ((w - zeros) * scales) - // = (x * (w - zeros)) * scales - auto pre_compute = [&](auto i, int64_t kgs) { constexpr int row = i / COLS; constexpr int col = i % COLS; vc[i] = _mm512_set1_ps(0.f); // reset accumulator - // load zeros and scales + // load scales if constexpr (row == 0 && col % 2 == 0) { - // Bz layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=uint8 - __m256i tmp = _mm256_loadu_si256(reinterpret_cast(Bz + kgs * strideBz + col * 16)); - // (w - (z - 15)) = (w - z + 15) - tmp = _mm256_sub_epi8(tmp, fifteen); - zeros[col] = _mm256_permutexvar_epi8(idx0, tmp); - zeros[col + 1] = _mm256_permutexvar_epi8(idx1, tmp); - // Bs layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=bf16 - __m512i tmp2 = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + col * 16)); - scales[col] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp2, 0)); - scales[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp2, 1)); + __m512i tmp = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + col * 16)); + scales[col] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 0)); + scales[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 1)); } }; auto compute = [&](auto i, int64_t k) { @@ -437,10 +339,8 @@ struct tinygemm_kernel_nn { // deinterleave and lookup to BF16 __m256i vb_i8_lo = vb_u4 & mask; __m256i vb_i8_hi = _mm256_srli_epi16(vb_u4, 4) & mask; - vb_i8_lo = _mm256_sub_epi8(vb_i8_lo, zeros[col]); - vb_i8_hi = _mm256_sub_epi8(vb_i8_hi, zeros[col + 1]); - vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), bf16_lut); - vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), bf16_lut); + vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), lut); + vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), lut); if constexpr (PREFETCH_SIZE_K > 0) { _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); @@ -471,10 +371,9 @@ struct tinygemm_kernel_nn { Unroll{}(storec); } }; -#endif -#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ - tinygemm_kernel_nn::apply( \ +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE, DATA_TYPE) \ + tinygemm_kernel_nn::apply( \ A + mb_start * lda, \ B + nb_start, \ C + 
mb_start * ldc + nb_start, \ @@ -487,7 +386,7 @@ struct tinygemm_kernel_nn { strideBz, \ strideBs); -template +template void tinygemm_kernel( const scalar_t* __restrict__ A, const unsigned char* __restrict__ B, @@ -518,31 +417,31 @@ void tinygemm_kernel( switch (mb_size << 4 | nb_size >> 4) { // mb_size = 1 case 0x12: - LAUNCH_TINYGEMM_KERNEL_NN(1, 32); + LAUNCH_TINYGEMM_KERNEL_NN(1, 32, DATA_TYPE); break; case 0x14: - LAUNCH_TINYGEMM_KERNEL_NN(1, 64); + LAUNCH_TINYGEMM_KERNEL_NN(1, 64, DATA_TYPE); break; // mb_size = 2 case 0x22: - LAUNCH_TINYGEMM_KERNEL_NN(2, 32); + LAUNCH_TINYGEMM_KERNEL_NN(2, 32, DATA_TYPE); break; case 0x24: - LAUNCH_TINYGEMM_KERNEL_NN(2, 64); + LAUNCH_TINYGEMM_KERNEL_NN(2, 64, DATA_TYPE); break; // mb_size = 3 case 0x32: - LAUNCH_TINYGEMM_KERNEL_NN(3, 32); + LAUNCH_TINYGEMM_KERNEL_NN(3, 32, DATA_TYPE); break; case 0x34: - LAUNCH_TINYGEMM_KERNEL_NN(3, 64); + LAUNCH_TINYGEMM_KERNEL_NN(3, 64, DATA_TYPE); break; // mb_size = 4 case 0x42: - LAUNCH_TINYGEMM_KERNEL_NN(4, 32); + LAUNCH_TINYGEMM_KERNEL_NN(4, 32, DATA_TYPE); break; case 0x44: - LAUNCH_TINYGEMM_KERNEL_NN(4, 64); + LAUNCH_TINYGEMM_KERNEL_NN(4, 64, DATA_TYPE); break; default: { std::fprintf(stderr, @@ -588,7 +487,7 @@ void gemv_4bit_inference(long long M, int64_t mb_size = std::min(M - mb_start, BLOCK_M); int64_t nb_start = nb * BLOCK_N; int64_t nb_size = std::min(N - nb_start, BLOCK_N); - tinygemm_kernel( + tinygemm_kernel( /* A */ x + mb_start * mat1_strideM, /* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2 /* C */ out + mb_start * out_strideM + nb_start, From bd02e71246db48571d4f0050469d5079c89d4ad1 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 10:40:43 +0000 Subject: [PATCH 45/78] fix def Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 13 +++++++++++++ csrc/cpu_ops.h | 50 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 97402f290..bf0a6beda 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -258,6 +258,19 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long #define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t*, + const unsigned char*, + scalar_t*, + const scalar_t*, + int64_t, int, int64_t, int64_t, int64_t, int64_t, int64_t) { + static_assert(sizeof(scalar_t) == 0, + "tinygemm_kernel_nn primary template should never be instantiated"); + } +}; + template struct tinygemm_kernel_nn { static inline void apply( diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index ae4e98af5..ce83b04d4 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -4,6 +4,10 @@ #include #include +#if defined(_OPENMP) + #include +#else + // amx-bf16 #define TILE_M 16 #define TILE_N 16 @@ -24,6 +28,52 @@ inline int get_cache_blocks(int chunk_size) { return std::max(1, int(L2_size / (chunk_size * sizeof(T)))); } +// forced unroll for perf critical path +#if __has_attribute(always_inline) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +template +struct Unroll { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct Unroll<1> { + template + ALWAYS_INLINE void operator()(const Func& f, Args... 
args) const {
+        f(std::integral_constant<int, 0>{}, args...);
+    }
+};
+
+template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
+inline T div_up(T x, T y) {
+    return (x + y - 1) / y;
+}
+
+inline int get_max_threads() {
+#if defined(_OPENMP)
+    return omp_get_max_threads();
+#else
+    unsigned hc = std::thread::hardware_concurrency();
+    return hc == 0 ? 1 : int(hc);
+#endif
+}
+
+int inline adjust_num_threads(int m) {
+    int actual_nth = get_max_threads();
+    if (m == 1) {
+        return actual_nth;
+    }
+    return std::max(1, (actual_nth >> 1) * 2);
+}
+
 template <typename func_t>
 inline void parallel_2d(int m, int n, const func_t& f) {
     // make sure we have even num_threads

From 852006919ea5fe37cbe05921ca530fb0bcb9832f Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 7 Nov 2025 11:21:18 +0000
Subject: [PATCH 46/78] fix type

Signed-off-by: jiqing-feng
---
 bitsandbytes/backends/cpu/ops.py | 26 ++++++++++++++++-----
 csrc/cpu_ops.cpp                 | 40 ++++++++++++++++----------------
 csrc/cpu_ops.h                   | 26 +++++++++------------
 csrc/pythonInterface.cpp         | 27 +++++++++++----------
 4 files changed, 66 insertions(+), 53 deletions(-)

diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index a32081a3d..3e01924fd 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -224,25 +224,39 @@ def _(
             if dtype != torch.bfloat16:
                 A = A.to(torch.bfloat16)
 
+            A = A.reshape(-1, shapeB[1] // 2)
             out_shape = (*A.shape[:-1], shapeB[0])
             out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
+            M = A.shape[0]
+            N = shapeB[0]
+            K = A.shape[1]
+            x_strideM = A.stride(0)
+            out_strideM = out.stride(0)
             if quant_type == "fp4":
-                lib.cdequantize_blockwise_cpu_fp4_bf16(
+                lib.gemv_4bit_inference_cpu_fp4_bf16(
+                    ct.c_int64(M),
+                    ct.c_int64(N),
+                    ct.c_int64(K),
                     get_ptr(A),
+                    get_ptr(B),
                     get_ptr(absmax),
                     get_ptr(out),
-                    ct.c_longlong(blocksize),
-                    ct.c_longlong(shape[0]),
-                    ct.c_longlong(shape[1]),
+                    ct.c_int64(blocksize),
+                    ct.c_int64(x_strideM),
+                    ct.c_int64(out_strideM),
                 )
             elif quant_type == "nf4":
-                lib.cdequantize_blockwise_cpu_nf4_bf16(
+                lib.gemv_4bit_inference_cpu_nf4_bf16(
+                    ct.c_int64(M),
+                    ct.c_int64(N),
+                    ct.c_int64(K),
                     get_ptr(A),
+                    get_ptr(B),
                     get_ptr(absmax),
                     get_ptr(out),
-                    ct.c_longlong(blocksize),
-                    ct.c_longlong(shape[0]),
-                    ct.c_longlong(shape[1]),
+                    ct.c_int64(blocksize),
+                    ct.c_int64(x_strideM),
+                    ct.c_int64(out_strideM),
                 )
 
             if dtype != torch.bfloat16:
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index bf0a6beda..402fc7b8e 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -469,16 +469,16 @@ void tinygemm_kernel(
     }
 }
 
 template <typename T, int DATA_TYPE>
-void gemv_4bit_inference(long long M,
-                         long long N,
-                         long long K,
+void gemv_4bit_inference(int64_t M,
+                         int64_t N,
+                         int64_t K,
                          T* x,
                          unsigned char* w,
                          const float* absmax,
                          T* out,
-                         long long blocksize,
-                         long long x_stride,
-                         long long out_stride) {
+                         int64_t blocksize,
+                         int64_t x_stride,
+                         int64_t out_stride) {
     constexpr int64_t BLOCK_M = block_size_m(); // 32
     constexpr int64_t BLOCK_N = block_size_n(); // 32
     const int64_t MB = div_up(M, BLOCK_M); // (x + y -1)/ y, res = 1 when M <= 32
     const int64_t NB = div_up(N, BLOCK_N);
     // TODO: enable brgemm in the future.
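+    // Blocking sketch (illustrative numbers, not measured): for M = 4,
+    // N = 4096, K = 4096 with BLOCK_M = BLOCK_N = 32, this yields MB = 1
+    // row tile and NB = 128 column tiles. get_cache_blocks(BLOCK_N * K)
+    // then groups column tiles so one strip of bf16 weights (256 KB per
+    // tile group) fits the 1 MB L2 budget, i.e. cache_blocks_nb = 4 here.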
// const bool use_brgemm = M > 4; // const bool use_brgemm_dequant_out = M > 512; - scalar_t* Btmp_start = nullptr; + // T* Btmp_start = nullptr; // l2 cache block for n int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N * K); parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { // for brgemm, use float32 for accumulate alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; - alignas(64) scalar_t Btmp_inner[BLOCK_N * BLOCK_K]; // BLOCK_K = 128 + alignas(64) T Btmp_inner[BLOCK_N * BLOCK_K]; // BLOCK_K = 128 for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { for (int64_t mb = begin_mb; mb < end_mb; ++mb) { // 0-1 for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { @@ -500,20 +500,20 @@ void gemv_4bit_inference(long long M, int64_t mb_size = std::min(M - mb_start, BLOCK_M); int64_t nb_start = nb * BLOCK_N; int64_t nb_size = std::min(N - nb_start, BLOCK_N); - tinygemm_kernel( - /* A */ x + mb_start * mat1_strideM, + tinygemm_kernel( + /* A */ x + mb_start * x_stride, /* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2 - /* C */ out + mb_start * out_strideM + nb_start, + /* C */ out + mb_start * out_stride + nb_start, /* Bs */ absmax + nb_start, /* Btmp */ Btmp_inner, /* Ctmp */ Ctmp, /* M */ mb_size, /* N */ nb_size, /* K */ K, - /* gs */ group_size, - /* lda */ mat1_strideM, + /* gs */ blocksize, // group_size + /* lda */ x_stride, /* ldb */ nb_size, - /* ldc */ out_strideM, + /* ldc */ out_stride, /* sBz */ N, /* sBs */ N); } @@ -554,16 +554,16 @@ template void dequantizeBlockwise4bitCpu( unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); // template void gemv_4bit_inference( -// long long M, long long N, long long K, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride); +// int64_t M, int64_t N, int64_t K, float* x, unsigned char* w, const float* absmax, float* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); // template void gemv_4bit_inference( -// long long M, long long N, long long K, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride); +// int64_t M, int64_t N, int64_t K, float* x, unsigned char* w, const float* absmax, float* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); // template void gemv_4bit_inference( -// long long M, long long N, long long K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride); +// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); // template void gemv_4bit_inference( -// long long M, long long N, long long K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride); +// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); template void gemv_4bit_inference( - long long M, long long N, long long K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride); + int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); template void 
gemv_4bit_inference( - long long M, long long N, long long K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride); + int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index ce83b04d4..3f342a265 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -3,10 +3,14 @@ #include #include +#include +#include +#include +#include #if defined(_OPENMP) #include -#else +#endif // amx-bf16 #define TILE_M 16 @@ -14,12 +18,8 @@ #define TILE_K 32 // block size for AMX gemm -constexpr int block_size_m() { - return 2 * TILE_M; -} -constexpr int block_size_n() { - return 2 * TILE_N; -} +constexpr int block_size_m() { return 2 * TILE_M; } +constexpr int block_size_n() { return 2 * TILE_N; } template inline int get_cache_blocks(int chunk_size) { @@ -53,9 +53,7 @@ struct Unroll<1> { }; template ::value, int>::type = 0> -inline T div_up(T x, T y) { - return (x + y - 1) / y; -} +inline T div_up(T x, T y) { return (x + y - 1) / y; } inline int get_max_threads() { #if defined(_OPENMP) @@ -66,11 +64,9 @@ inline int get_max_threads() { #endif } -int inline adjust_num_threads(int m) { +inline int adjust_num_threads(int m) { int actual_nth = get_max_threads(); - if (m == 1) { - return actual_nth; - } + if (m == 1) return actual_nth; return std::max(1, (actual_nth >> 1) * 2); } @@ -285,7 +281,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, l #if defined(__AVX512F__) && defined(__AVX512BF16__) template - void gemv_4bit_inference(long long M, long long N, long long K, T* x, unsigned char* w, const float* absmax, T* out, long long blocksize, long long x_stride, long long out_stride); + void gemv_4bit_inference(int64_t M, int64_t N, int64_t K, T* x, unsigned char* w, const float* absmax, T* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); #endif #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 74479376f..0991d19a9 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -890,35 +890,38 @@ void cdequantize_blockwise_cpu_nf4_fp16( dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } +#if defined(__AVX512F__) && defined(__AVX512BF16__) // void gemv_4bit_inference_cpu_fp4_fp32( -// long long M, long long N, long long K, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride +// int64_t M, int64_t N, int64_t K, float* x, unsigned char* w, const float* absmax, float* out, int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { -// gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +// gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } // void gemv_4bit_inference_cpu_fp4_fp16( -// long long M, long long N, long long K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride +// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { -// gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +// gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } void gemv_4bit_inference_cpu_fp4_bf16( - long long M, long long N, long long K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, 
long long blocksize, long long x_stride, long long out_stride + int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride ) { - gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); + gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); } // void gemv_4bit_inference_cpu_nf4_fp32( -// long long M, long long N, long long K, float* x, unsigned char* w, const float* absmax, float* out, long long blocksize, long long x_stride, long long out_stride +// int64_t M, int64_t N, int64_t K, float* x, unsigned char* w, const float* absmax, float* out, int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { -// gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +// gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } // void gemv_4bit_inference_cpu_nf4_fp16( -// long long M, long long N, long long K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, long long blocksize, long long x_stride, long long out_stride +// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { -// gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); +// gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } void gemv_4bit_inference_cpu_nf4_bf16( - long long M, long long N, long long K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, long long blocksize, long long x_stride, long long out_stride + int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, + int64_t blocksize, int64_t x_stride, int64_t out_stride ) { - gemv_4bit_inference(m, n, k, x, w, absmax, out, blocksize); + gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); } +#endif } From e921cbb5d9b4994cb29b29e1b3d77f4a361c9eb0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 12:13:52 +0000 Subject: [PATCH 47/78] fix absmax dtype Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 6 +++--- csrc/cpu_ops.h | 4 +++- csrc/pythonInterface.cpp | 8 ++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 402fc7b8e..9dab93fa2 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -474,7 +474,7 @@ void gemv_4bit_inference(int64_t M, int64_t K, T* x, unsigned char* w, - const float* absmax, + const T* absmax, T* out, int64_t blocksize, int64_t x_stride, @@ -564,6 +564,6 @@ template void dequantizeBlockwise4bitCpu( // int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); template void gemv_4bit_inference( - int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const bf16_t* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); template void gemv_4bit_inference( - int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const bf16_t* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 
3f342a265..ad6771eb8 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -16,6 +16,8 @@ #define TILE_M 16 #define TILE_N 16 #define TILE_K 32 +// work around compiler internal error +#define BLOCK_K 128 // 4 * TILE_K // block size for AMX gemm constexpr int block_size_m() { return 2 * TILE_M; } @@ -281,7 +283,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, l #if defined(__AVX512F__) && defined(__AVX512BF16__) template - void gemv_4bit_inference(int64_t M, int64_t N, int64_t K, T* x, unsigned char* w, const float* absmax, T* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + void gemv_4bit_inference(int64_t M, int64_t N, int64_t K, T* x, unsigned char* w, const T* absmax, T* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); #endif #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 0991d19a9..a88da2612 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -897,12 +897,12 @@ void cdequantize_blockwise_cpu_nf4_fp16( // gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } // void gemv_4bit_inference_cpu_fp4_fp16( -// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride +// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const fp16_t* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { // gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } void gemv_4bit_inference_cpu_fp4_bf16( - int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride + int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const bf16_t* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride ) { gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); } @@ -913,12 +913,12 @@ void gemv_4bit_inference_cpu_fp4_bf16( // gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } // void gemv_4bit_inference_cpu_nf4_fp16( -// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride +// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const fp16_t* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { // gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } void gemv_4bit_inference_cpu_nf4_bf16( - int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const float* absmax, bf16_t* out, + int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const bf16_t* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride ) { gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); From 9b5d97a38e1766627b26a9fb5d96876901436c56 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 12:29:37 +0000 Subject: [PATCH 48/78] fix type Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 25 ++++++++++++------------- csrc/pythonInterface.cpp | 29 +++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 9dab93fa2..ea3137eb4 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -277,7 +277,7 @@ struct tinygemm_kernel_nn { const bf16_t* __restrict__ A, const 
unsigned char* __restrict__ B, bf16_t* __restrict__ C, - bf16_t* __restrict__ Bs, + const bf16_t* __restrict__ Bs, int64_t K, int group_size, int64_t lda, @@ -472,10 +472,10 @@ template void gemv_4bit_inference(int64_t M, int64_t N, int64_t K, - T* x, - unsigned char* w, - const T* absmax, - T* out, + const T* __restrict__ x, + const unsigned char* __restrict__ w, + const T* __restrict__ absmax, + T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride) { @@ -554,16 +554,15 @@ template void dequantizeBlockwise4bitCpu( unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); // template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, float* x, unsigned char* w, const float* absmax, float* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); +// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float* __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); // template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, float* x, unsigned char* w, const float* absmax, float* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); - +// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float* __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); +// // template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); +// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float* __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); // template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const float* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); - +// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float* __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); template void gemv_4bit_inference( - int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const bf16_t* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); template void gemv_4bit_inference( - int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const bf16_t* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); \ No newline at end of file diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index a88da2612..616db6e64 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -892,33 +892,50 @@ void cdequantize_blockwise_cpu_nf4_fp16( #if defined(__AVX512F__) && defined(__AVX512BF16__) // void gemv_4bit_inference_cpu_fp4_fp32( -// int64_t M, int64_t N, int64_t K, float* x, unsigned char* w, 
const float* absmax, float* out, int64_t blocksize, int64_t x_stride, int64_t out_stride +// int64_t M, int64_t N, int64_t K, +// const float* __restrict__ x, const unsigned char* __restrict__ w, +// const float* __restrict__ absmax, float* __restrict__ out, +// int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { // gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } // void gemv_4bit_inference_cpu_fp4_fp16( -// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const fp16_t* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride +// int64_t M, int64_t N, int64_t K, +// const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, +// const fp16_t* __restrict__ absmax, fp16_t* __restrict__ out, +// int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { // gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } void gemv_4bit_inference_cpu_fp4_bf16( - int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const bf16_t* absmax, bf16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride + int64_t M, int64_t N, int64_t K, + const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, + int64_t blocksize, int64_t x_stride, int64_t out_stride ) { gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); } // void gemv_4bit_inference_cpu_nf4_fp32( -// int64_t M, int64_t N, int64_t K, float* x, unsigned char* w, const float* absmax, float* out, int64_t blocksize, int64_t x_stride, int64_t out_stride +// int64_t M, int64_t N, int64_t K, +// const float* __restrict__ x, const unsigned char* __restrict__ w, +// const float* __restrict__ absmax, float* __restrict__ out, +// int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { // gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } // void gemv_4bit_inference_cpu_nf4_fp16( -// int64_t M, int64_t N, int64_t K, fp16_t* x, unsigned char* w, const fp16_t* absmax, fp16_t* out, int64_t blocksize, int64_t x_stride, int64_t out_stride +// int64_t M, int64_t N, int64_t K, +// const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, +// const fp16_t* __restrict__ absmax, fp16_t* __restrict__ out, +// int64_t blocksize, int64_t x_stride, int64_t out_stride // ) { // gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); // } void gemv_4bit_inference_cpu_nf4_bf16( - int64_t M, int64_t N, int64_t K, bf16_t* x, unsigned char* w, const bf16_t* absmax, bf16_t* out, + int64_t M, int64_t N, int64_t K, + const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride ) { gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); From fd6cff130f320510cbf0122caaf4b754ad645a37 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 12:36:30 +0000 Subject: [PATCH 49/78] fix compile and type Signed-off-by: jiqing-feng --- CMakeLists.txt | 19 ++++++++++++++----- csrc/cpu_ops.h | 2 +- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d4a492c8..3a351b575 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -272,16 +272,25 @@ target_include_directories(bitsandbytes PUBLIC csrc include) if (BUILD_CPU) target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX) 
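    # Illustrative note, not part of the original patch: check_cxx_compiler_flag(<flag> <var>)
    # from the CheckCXXCompilerFlag module compiles a tiny test source with <flag> and caches
    # the pass/fail result in <var>, so each -mavx512* option below is only added when the
    # host compiler actually accepts it.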
include(CheckCXXCompilerFlag) - check_cxx_compiler_flag(-mavx512f HAS_AVX512F) check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16) - + check_cxx_compiler_flag(-mavx512dq HAS_AVX512DQ) + check_cxx_compiler_flag(-mavx512bw HAS_AVX512BW) + check_cxx_compiler_flag(-mavx512vl HAS_AVX512VL) if(HAS_AVX512F) - target_compile_options(bitsandbytes PRIVATE -mavx512f) + target_compile_options(bitsandbytes PRIVATE -mavx512f) endif() - if(HAS_AVX512BF16) - target_compile_options(bitsandbytes PRIVATE -mavx512bf16) + target_compile_options(bitsandbytes PRIVATE -mavx512bf16) + endif() + if(HAS_AVX512DQ) + target_compile_options(bitsandbytes PRIVATE -mavx512dq) + endif() + if(HAS_AVX512BW) + target_compile_options(bitsandbytes PRIVATE -mavx512bw) + endif() + if(HAS_AVX512VL) + target_compile_options(bitsandbytes PRIVATE -mavx512vl) endif() endif() diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index ad6771eb8..9aa401e4d 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -283,7 +283,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, l #if defined(__AVX512F__) && defined(__AVX512BF16__) template - void gemv_4bit_inference(int64_t M, int64_t N, int64_t K, T* x, unsigned char* w, const T* absmax, T* out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + void gemv_4bit_inference(int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); #endif #endif From 46d6e47a501cc7b1260ddd38976cf7c904bdccb4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 12:57:19 +0000 Subject: [PATCH 50/78] enable gemv Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 6 +++++ bitsandbytes/functional.py | 41 +++++++++++++++++++++++++++++ bitsandbytes/nn/modules.py | 9 ++++--- bitsandbytes/utils.py | 35 +----------------------- 4 files changed, 53 insertions(+), 38 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 460e3a507..df778c84f 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -437,6 +437,12 @@ def matmul_4bit( if A.device.type == "cpu": quant_state.dtype = A.dtype + if getattr(quant_state, "enable_optimized_cpu", False): + out = F.gemv_4bit(A, B, out, state=quant_state) + if bias is not None: + out += bias + return out + if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: warn( diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 3d11276ad..9e488b104 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2199,4 +2199,45 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): return out +def convert_weight_packed_for_cpu(qweight: torch.Tensor, + quant_state: QuantState, + block_n: int = 32): + """ + qweight: (K * N / 2) uint8 + return: packed_weight + """ + assert qweight.dtype == torch.uint8, "qweight must be uint8" + qweight = qweight.reshape(-1) + unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=qweight.device) + unpacked_w[1::2] = qweight & 0xF + unpacked_w[::2] = qweight >> 4 + qweight_final = unpacked_w.reshape(quant_state.shape).to(torch.uint8) # (*, N, K) + # pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit + assert len(qweight_final.shape) == 2 + N, K = qweight_final.shape[0], qweight_final.shape[1] + assert N % block_n == 0, "N must be 
divisible by block_n" + assert K % 2 == 0, "K must be even" + BLOCK_N = block_n + BIT_COUNT = 32 # (=32 low +32 high) + new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2] + out_shape = [N, K // 2] + qw = qweight_final.reshape(new_shape) # (..., N/B, B, K/2, 2) + qw = qw.transpose(-3, -2).contiguous() # (..., N/B, K/2, B, 2) + qw = qw.reshape(-1, BIT_COUNT * 2) # [-1, 64] + high = qw[:, BIT_COUNT:] # high 32 + low = qw[:, :BIT_COUNT] # low 32 + packed = ((high << 4) | low).to(torch.uint8) # combine + final_qweight = packed.reshape(out_shape) + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + quant_state.absmax = absmax.T.to(torch.bfloat16) + quant_state.nested = False + delattr(quant_state, "state2") + return final_qweight, quant_state + + C = 127.0 diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 76d41b081..d08e43d88 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,9 +12,9 @@ import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT -from bitsandbytes.functional import QuantState +from bitsandbytes.functional import QuantState, convert_weight_packed_for_cpu from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer, convert_weight_packed_for_cpu +from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer from ..cextension import ErrorHandlerMockBNBNativeLibrary, lib T = TypeVar("T", bound="torch.nn.Module") @@ -517,9 +517,10 @@ def forward(self, x: torch.Tensor): quant_state = self.weight.quant_state fix_4bit_weight_quant_state_from_module(self) - if not self.enable_optimized_cpu and not isinstance(lib, ErrorHandlerMockBNBNativeLibrary) and hasattr(lib, "gemv_4bit_inference_cpu_nf4_bf16"): - self.weight.data, quant_state.absmax = convert_weight_packed_for_cpu(self.weight.data, quant_state.absmax, quant_state.shape) + if not self.enable_optimized_cpu and not isinstance(lib, ErrorHandlerMockBNBNativeLibrary) and hasattr(lib, "gemv_4bit_inference_cpu_nf4_bf16") and not self.training and x.requires_grad == False: + self.weight.data, quant_state = convert_weight_packed_for_cpu(self.weight.data, quant_state) self.enable_optimized_cpu = True + setattr(quant_state, "enable_optimized_cpu", True) # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 0d0fb6be9..dcb7798af 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -3,7 +3,7 @@ import subprocess import torch - +from collections.abc import Sequence def outlier_hook(module, input): assert isinstance(module, torch.nn.Linear) @@ -203,36 +203,3 @@ def sync_gpu(t: torch.Tensor): torch.cuda.synchronize() elif t.device.type == "xpu": torch.xpu.synchronize() - -def convert_weight_packed_for_cpu(qweight: torch.Tensor, - scales: torch.Tensor, - shape: Sequence[int], - block_n: int = 32): - """ - qweight: (K * N / 2) uint8 - return: packed_weight - """ - assert qweight.dtype == torch.uint8, "qweight must be uint8" - qweight = qweight.reshape(-1) - unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=A.device) - unpacked_w[1::2] = qweight & 0xF - unpacked_w[::2] = qweight >> 4 - qweight_final = unpacked_w.reshape(shape).to(torch.uint8) # (*, N, K) - # 
pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit - assert len(qweight_final.shape) == 2 - N, K = qweight_final.shape[0], qweight_final.shape[1] - assert N % block_n == 0, "N must be divisible by block_n" - assert K % 2 == 0, "K must be even" - BLOCK_N = block_n - BIT_COUNT = 32 # (=32 low +32 high) - prefix = sizes[:-2] - new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2] - out_shape = [N, K // 2] - qw = qweight_final.reshape(new_shape) # (..., N/B, B, K/2, 2) - qw = qw.transpose(-3, -2).contiguous() # (..., N/B, K/2, B, 2) - qw = qw.reshape(-1, BIT_COUNT * 2) # [-1, 64] - high = qw[:, BIT_COUNT:] # high 32 - low = qw[:, :BIT_COUNT] # low 32 - packed = ((high << 4) | low).to(torch.uint8) # combine - final_qweight = packed.reshape(out_shape) - return final_qweight, scales.T.to(torch.bfloat16) From 3271c3088f09b76a44a8ef4b6b14b0b41e0955e0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 13:22:29 +0000 Subject: [PATCH 51/78] fix shape Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 4 ++-- bitsandbytes/functional.py | 2 +- bitsandbytes/nn/modules.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 3e01924fd..d7a650d8e 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -224,10 +224,10 @@ def _( if dtype != torch.bfloat16: A = A.to(torch.bfloat16) - A = A.reshape(-1, shapeB[1] // 2) + A = A.reshape(-1, A.shape[-1]) out_shape = (*A.shape[:-1], shapeB[0]) out = torch.empty(out_shape, dtype=A.dtype, device=A.device) - M = A.shape(0) + M = A.shape[0] N = shapeB[0] K = A.shape[1] x_strideM = A.stride(0) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9e488b104..daa965ae5 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2234,7 +2234,7 @@ def convert_weight_packed_for_cpu(qweight: torch.Tensor, if absmax.dtype != torch.float32: absmax = absmax.float() - quant_state.absmax = absmax.T.to(torch.bfloat16) + quant_state.absmax = absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize).T.to(torch.bfloat16) quant_state.nested = False delattr(quant_state, "state2") return final_qweight, quant_state diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index d08e43d88..8f8d098e2 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -535,7 +535,7 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) - weight = self.weight.t() + weight = self.weight if getattr(quant_state, "enable_optimized_cpu", False) else self.weight.t() return bnb.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype) From 176a2b61faab558d837af4d0bdf42fec01bb3c1e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 13:41:58 +0000 Subject: [PATCH 52/78] fix lib name Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index d7a650d8e..19952c100 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -233,7 +233,7 @@ def _( x_strideM = A.stride(0) out_strideM = out.stride(0) if quant_type == "fp4": - lib.cdequantize_blockwise_cpu_fp4_bf16( + lib.gemv_4bit_inference_cpu_fp4_bf16( ct.c_int64(M), ct.c_int64(N), ct.c_int64(K), @@ -246,7 +246,7 @@ def _( 
ct.c_int64(out_strideM), ) elif quant_type == "nf4": - lib.cdequantize_blockwise_cpu_nf4_bf16( + lib.gemv_4bit_inference_cpu_nf4_bf16( ct.c_int64(M), ct.c_int64(N), ct.c_int64(K), From 196984a7cdb9ddb99e2769d9f44a756ec9114426 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 14:08:28 +0000 Subject: [PATCH 53/78] debug Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 3 +++ csrc/cpu_ops.cpp | 24 ++++++++++++------------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 19952c100..1080f4f0c 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -246,6 +246,9 @@ def _( ct.c_int64(out_strideM), ) elif quant_type == "nf4": + #print(A) + #print(B) + #print(absmax) lib.gemv_4bit_inference_cpu_nf4_bf16( ct.c_int64(M), ct.c_int64(N), diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index ea3137eb4..58971ccb7 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -299,19 +299,19 @@ struct tinygemm_kernel_nn { __m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit __m512i lut = DATA_TYPE == 1 ? _mm512_set_epi16( - /* e31..e16 copy e15..e0 */ - 0x0000, 0x3A45, 0x3F30, 0x3EAA, 0x3F80, 0x3F2A, 0x3B4F, 0x3E80, - 0x0000, 0xBA45, 0xBF30, 0xBEAA, 0xBF80, 0xBF2A, 0xBB4F, 0xBE80, - /* e15..e0 original index 0..15 */ - 0x0000, 0x3A45, 0x3F30, 0x3EAA, 0x3F80, 0x3F2A, 0x3B4F, 0x3E80, - 0x0000, 0xBA45, 0xBF30, 0xBEAA, 0xBF80, 0xBF2A, 0xBB4F, 0xBE80 + /* reversed: e31..e16 copy e15..e0 */ + 0xBE80, 0xBB4F, 0xBF2A, 0xBF80, 0xBEAA, 0xBF30, 0xBA45, 0x0000, + 0x3E80, 0x3B4F, 0x3F2A, 0x3F80, 0x3EAA, 0x3F30, 0x3A45, 0x0000, + /* reversed: e15..e0 original index 0..15 */ + 0xBE80, 0xBB4F, 0xBF2A, 0xBF80, 0xBEAA, 0xBF30, 0xBA45, 0x0000, + 0x3E80, 0x3B4F, 0x3F2A, 0x3F80, 0x3EAA, 0x3F30, 0x3A45, 0x0000 ) : _mm512_set_epi16( - /* e31..e16 copy e15..e0 */ - 0xBF80, 0xBFA5, 0xBF0C, 0xBECA, 0xBE84, 0xBE1C, 0xBDA4, 0xBD28, - 0x0000, 0x3D2E, 0x3D5F, 0x3DAE, 0x3DF0, 0x3E0C, 0x3E38, 0x3F80, - /* e15..e0 original index 0..15 */ - 0xBF80, 0xBFA5, 0xBF0C, 0xBECA, 0xBE84, 0xBE1C, 0xBDA4, 0xBD28, - 0x0000, 0x3D2E, 0x3D5F, 0x3DAE, 0x3DF0, 0x3E0C, 0x3E38, 0x3F80 + /* reversed: e31..e16 copy e15..e0 */ + 0x3F80, 0x3E38, 0x3E0C, 0x3DF0, 0x3DAE, 0x3D5F, 0x3D2E, 0x0000, + 0xBD28, 0xBDA4, 0xBE1C, 0xBE84, 0xBECA, 0xBF0C, 0xBFA5, 0xBF80, + /* reversed: e15..e0 original index 0..15 */ + 0x3F80, 0x3E38, 0x3E0C, 0x3DF0, 0x3DAE, 0x3D5F, 0x3D2E, 0x0000, + 0xBD28, 0xBDA4, 0xBE1C, 0xBE84, 0xBECA, 0xBF0C, 0xBFA5, 0xBF80 ); __m512 scales[COLS]; const int64_t K2 = K >> 1; From 76521152aa7a93d2d36ba6ca195371df66347a1c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 11 Nov 2025 11:42:27 +0000 Subject: [PATCH 54/78] update Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 26 ++++++++++++++------------ csrc/cpu_ops.h | 7 +++++++ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 58971ccb7..e06d5161e 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -298,21 +298,21 @@ struct tinygemm_kernel_nn { __m512 vc_master[ROWS * COLS]; __m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit + __m256i fifteen = _mm256_set1_epi8(15); __m512i lut = DATA_TYPE == 1 ? 
_mm512_set_epi16( - /* reversed: e31..e16 copy e15..e0 */ - 0xBE80, 0xBB4F, 0xBF2A, 0xBF80, 0xBEAA, 0xBF30, 0xBA45, 0x0000, - 0x3E80, 0x3B4F, 0x3F2A, 0x3F80, 0x3EAA, 0x3F30, 0x3A45, 0x0000, - /* reversed: e15..e0 original index 0..15 */ - 0xBE80, 0xBB4F, 0xBF2A, 0xBF80, 0xBEAA, 0xBF30, 0xBA45, 0x0000, - 0x3E80, 0x3B4F, 0x3F2A, 0x3F80, 0x3EAA, 0x3F30, 0x3A45, 0x0000 + 0xBF80, 0x3F80, 0x3F39, 0x3F10, 0x3EE1, 0x3EAD, 0x3E7C, 0x3E24, 0x3DA2, 0x0000, 0xBDBA, 0xBE3D, 0xBE91, 0xBECA, 0xBF06, 0xBF32, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 ) : _mm512_set_epi16( - /* reversed: e31..e16 copy e15..e0 */ - 0x3F80, 0x3E38, 0x3E0C, 0x3DF0, 0x3DAE, 0x3D5F, 0x3D2E, 0x0000, - 0xBD28, 0xBDA4, 0xBE1C, 0xBE84, 0xBECA, 0xBF0C, 0xBFA5, 0xBF80, - /* reversed: e15..e0 original index 0..15 */ - 0x3F80, 0x3E38, 0x3E0C, 0x3DF0, 0x3DAE, 0x3D5F, 0x3D2E, 0x0000, - 0xBD28, 0xBDA4, 0xBE1C, 0xBE84, 0xBECA, 0xBF0C, 0xBFA5, 0xBF80 + 0xBF80, 0x3F80, 0x3F39, 0x3F10, 0x3EE1, 0x3EAD, 0x3E7C, 0x3E24, 0x3DA2, 0x0000, 0xBDBA, 0xBE3D, 0xBE91, 0xBECA, 0xBF06, 0xBF32, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 ); + __m512i bf16_lut = _mm512_set_epi16( + // 0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 + 0x0000, 0x4170, 0x4160, 0x4150, 0x4140, 0x4130, 0x4120, 0x4110, 0x4100, 0x40E0, 0x40C0, 0x40A0, 0x4080, 0x4040, 0x4000, 0x3F80, + // 16 .. 31 + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 + ); + __m512 scales[COLS]; const int64_t K2 = K >> 1; const int64_t lda2 = lda >> 1; @@ -352,6 +352,8 @@ struct tinygemm_kernel_nn { // deinterleave and lookup to BF16 __m256i vb_i8_lo = vb_u4 & mask; __m256i vb_i8_hi = _mm256_srli_epi16(vb_u4, 4) & mask; + // vb_i8_lo = _mm256_add_epi8(vb_i8_lo, fifteen); + // vb_i8_hi = _mm256_add_epi8(vb_i8_hi, fifteen); vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), lut); vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), lut); diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 9aa401e4d..7178d32a8 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -144,6 +144,13 @@ static inline bf16_t float_to_bf16(float x) { return bf16_t{static_cast(r >> 16)}; } +static float bf16_to_float(uint16_t bf16) { + uint32_t bits = (uint32_t)bf16 << 16; + float f; + std::memcpy(&f, &bits, sizeof(f)); + return f; +} + static inline fp16_t float_to_fp16(float x) { uint32_t bits; std::memcpy(&bits, &x, 4); From ea0e64976717807165e973d4365dfc0241c36281 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 12 Nov 2025 11:16:03 +0000 Subject: [PATCH 55/78] enable gemv 4bit bf16 Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 84 +++-- bitsandbytes/functional.py | 24 +- bitsandbytes/nn/modules.py | 12 +- csrc/cpu_ops.cpp | 518 ++++++++++++++----------------- 4 files changed, 293 insertions(+), 345 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 1080f4f0c..6c09c0aca 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -6,7 +6,6 @@ from bitsandbytes.functional import get_ptr -from ..utils import CODE from ..._ops import register_kernel from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib @@ -190,9 +189,6 @@ def _( ct.c_longlong(shape[0]), ct.c_longlong(shape[1]), ) - # out_2 = 
dequantize_nf4_test(_reverse_4bit_compress_format(A.reshape(-1)), absmax, blocksize, quant_type, shape, dtype) - # if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): - # import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( get_ptr(A), @@ -208,6 +204,7 @@ def _( return out if hasattr(lib, "gemv_4bit_inference_cpu_nf4_bf16"): + @register_kernel("bitsandbytes::gemv_4bit", "cpu") def _( A: torch.Tensor, @@ -246,9 +243,9 @@ def _( ct.c_int64(out_strideM), ) elif quant_type == "nf4": - #print(A) - #print(B) - #print(absmax) + # print(A) + # print(B) + # print(absmax) lib.gemv_4bit_inference_cpu_nf4_bf16( ct.c_int64(M), ct.c_int64(N), @@ -267,40 +264,39 @@ def _( return out - -def dequantize_nf4_test( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -): - # Map nf4 to [-1, 1] - out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - n = out_dq.numel() - out_dq[1::2] = A & 0xF - out_dq[::2] = A >> 4 - # code is fp32, cast to dtype to avoid the mismatch issue - code = CODE[quant_type].to(dtype).to(A.device) - out_dq = code[out_dq] - - # Apply scales - if out_dq.numel() != n: - assert out_dq.numel() == n + 1 - out_dq = torch.narrow(out_dq, 0, 0, n) - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - rem = n % blocksize - has_rem = rem > 0 - - if has_rem: - out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) - out[n - rem :] = out_dq[n - rem :] * absmax[-1] - else: - out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) - - out = out.reshape(-1, *shape[1:]).to(dtype) - - return out - +# def unpack_weight_packed_for_cpu(packed_qweight: torch.Tensor, block_n: int = 32): +# """ +# Inverse of convert_weight_packed_for_cpu. 
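+# (Illustrative note, not part of the original patch: kept commented out as a
+# plain-PyTorch oracle for debugging the packed GEMV path.)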
+# packed_qweight: (N, K//2) uint8, each byte = (high<<4)|low, both 4-bit values in 0..15 +# returns: qweight_final (N, K) uint8 with original 4-bit values (0..15) +# """ +# assert packed_qweight.dtype == torch.uint8 +# assert packed_qweight.dim() == 2 +# N, K_half = packed_qweight.shape +# assert N % block_n == 0 +# BIT_COUNT = block_n # 32 +# # reshape to rows of 32 packed bytes +# qw = packed_qweight.reshape(-1, BIT_COUNT) # [(N//block_n)*K_half, 32] +# low = (qw & 0x0F) +# high = (qw >> 4) & 0x0F +# # restore 64 nibbles (low first then high, matching original pack order) +# restored = torch.cat([low, high], dim=1) # [..., 64] +# # reshape back (inverse of flatten) +# restored = restored.reshape(N // block_n, K_half, block_n, 2) # [N/block_n, K//2, block_n, 2] +# # inverse transpose +# restored = restored.transpose(-3, -2) # [N/block_n, block_n, K//2, 2] +# # final shape +# qweight_final = restored.reshape(N, K_half * 2).to(torch.uint8) +# return qweight_final + + +# _NF4_QUANT_TABLE = torch.tensor([ -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, +# 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0 ], dtype=torch.float32) + +# def fused_matmul(x, packed_weight, scales, group_size): +# unpacked_weight = unpack_weight_packed_for_cpu(packed_weight) +# shape = unpacked_weight.shape +# # original_weight = _INT4_0_TO_15_TABLE[unpacked_weight.reshape(-1).int()].reshape(shape) * scales.T.repeat_interleave(group_size, dim=1) +# original_weight = _NF4_QUANT_TABLE[unpacked_weight.reshape(-1).int()].reshape(shape) * scales.T.repeat_interleave(group_size, dim=1) +# res = torch.matmul(x, original_weight.T.to(x.dtype)) +# return res diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index daa965ae5..311e60f26 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2199,9 +2199,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): return out -def convert_weight_packed_for_cpu(qweight: torch.Tensor, - quant_state: QuantState, - block_n: int = 32): +def convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32): """ qweight: (K * N / 2) uint8 return: packed_weight @@ -2221,12 +2219,12 @@ def convert_weight_packed_for_cpu(qweight: torch.Tensor, BIT_COUNT = 32 # (=32 low +32 high) new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2] out_shape = [N, K // 2] - qw = qweight_final.reshape(new_shape) # (..., N/B, B, K/2, 2) - qw = qw.transpose(-3, -2).contiguous() # (..., N/B, K/2, B, 2) - qw = qw.reshape(-1, BIT_COUNT * 2) # [-1, 64] - high = qw[:, BIT_COUNT:] # high 32 - low = qw[:, :BIT_COUNT] # low 32 - packed = ((high << 4) | low).to(torch.uint8) # combine + qw = qweight_final.reshape(new_shape) # (..., N/B, B, K/2, 2) + qw = qw.transpose(-3, -2).contiguous() # (..., N/B, K/2, B, 2) + qw = qw.reshape(-1, BIT_COUNT * 2) # [-1, 64] + high = qw[:, BIT_COUNT:] # high 32 + low = qw[:, :BIT_COUNT] # low 32 + packed = ((high << 4) | low).to(torch.uint8) # combine final_qweight = packed.reshape(out_shape) if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) @@ -2234,9 +2232,15 @@ def convert_weight_packed_for_cpu(qweight: torch.Tensor, if absmax.dtype != torch.float32: absmax = absmax.float() - quant_state.absmax = absmax.reshape(quant_state.shape[0], quant_state.shape[1] // 
quant_state.blocksize).T.to(torch.bfloat16) + quant_state.absmax = ( + absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize) + .T.to(torch.bfloat16) + .contiguous() + ) quant_state.nested = False delattr(quant_state, "state2") + + quant_state.dtype = torch.bfloat16 return final_qweight, quant_state diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 8f8d098e2..4683e3e7f 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -15,6 +15,7 @@ from bitsandbytes.functional import QuantState, convert_weight_packed_for_cpu from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer + from ..cextension import ErrorHandlerMockBNBNativeLibrary, lib T = TypeVar("T", bound="torch.nn.Module") @@ -517,10 +518,17 @@ def forward(self, x: torch.Tensor): quant_state = self.weight.quant_state fix_4bit_weight_quant_state_from_module(self) - if not self.enable_optimized_cpu and not isinstance(lib, ErrorHandlerMockBNBNativeLibrary) and hasattr(lib, "gemv_4bit_inference_cpu_nf4_bf16") and not self.training and x.requires_grad == False: + if ( + not self.enable_optimized_cpu + and x.device.type == "cpu" + and not isinstance(lib, ErrorHandlerMockBNBNativeLibrary) + and hasattr(lib, "gemv_4bit_inference_cpu_nf4_bf16") + and not self.training + and x.requires_grad == False + ): self.weight.data, quant_state = convert_weight_packed_for_cpu(self.weight.data, quant_state) self.enable_optimized_cpu = True - setattr(quant_state, "enable_optimized_cpu", True) + quant_state.enable_optimized_cpu = True # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e06d5161e..2405cf2d0 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,89 +5,61 @@ using namespace BinSearch; - #if defined(__AVX512F__) #include inline __m256i cvt_fp32_to_fp16(const __m512 src) { return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - } +} inline __m256i cvt_fp32_to_bf16(const __m512 src) { - #if defined(__AVX512BF16__) - return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src)); - #else - __m512i value = _mm512_castps_si512(src); - __m512i nan = _mm512_set1_epi32(0xffff); - auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); - __m512i ones = _mm512_set1_epi32(0x1); - __m512i vec_bias = _mm512_set1_epi32(0x7fff); - // uint32_t lsb = (input >> 16) & 1; - auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); - // uint32_t rounding_bias = 0x7fff + lsb; - t_value = _mm512_add_epi32(t_value, vec_bias); - // input += rounding_bias; - t_value = _mm512_add_epi32(t_value, value); - // input = input >> 16; - t_value = _mm512_srli_epi32(t_value, 16); - // Check NaN before converting back to bf16 - t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); - return _mm512_cvtusepi32_epi16(t_value); - #endif +#if defined(__AVX512BF16__) + return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src)); +#else + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = 
_mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +#endif } static inline __m512 set_nf4_lut() { return _mm512_set_ps( - 1.0f, - 0.7229568362236023, - 0.5626170039176941, - 0.44070982933044434, - 0.33791524171829224, - 0.24611230194568634, - 0.16093020141124725, - 0.07958029955625534, - 0.0f, - -0.09105003625154495, - -0.18477343022823334, - -0.28444138169288635, - -0.39491748809814453, - -0.5250730514526367, - -0.6961928009986877, - -1.0f); + 1.0f, 0.7229568362236023, 0.5626170039176941, 0.44070982933044434, 0.33791524171829224, 0.24611230194568634, + 0.16093020141124725, 0.07958029955625534, 0.0f, -0.09105003625154495, -0.18477343022823334, + -0.28444138169288635, -0.39491748809814453, -0.5250730514526367, -0.6961928009986877, -1.0f + ); } + static inline __m512 set_fp4_lut() { return _mm512_set_ps( - -0.2500f, - -0.16666667f, - -0.5000f, - -0.33333333f, - -1.0000f, - -0.66666667f, - -5.208333333e-03f, - 0.0000f, - 0.2500f, - 0.16666667f, - 0.5000f, - 0.33333333f, - 1.0000f, - 0.66666667f, - 5.208333333e-03f, - 0.0000f); + -0.2500f, -0.16666667f, -0.5000f, -0.33333333f, -1.0000f, -0.66666667f, -5.208333333e-03f, 0.0000f, 0.2500f, + 0.16666667f, 0.5000f, 0.33333333f, 1.0000f, 0.66666667f, 5.208333333e-03f, 0.0000f + ); } #endif // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. // DATA_TYPE: 1 = FP4, 0 = NF4 template -void dequantizeBlockwise4bitCpu(unsigned char* A, - const float* absmax, - T* out, - long long blocksize, - long long m, - long long n) { - static_assert(DATA_TYPE == 0 || DATA_TYPE == 1, - "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); - if (blocksize <= 0 || m < 0 || n <= 0) return; +void dequantizeBlockwise4bitCpu( + unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n +) { + static_assert(DATA_TYPE == 0 || DATA_TYPE == 1, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); + if (blocksize <= 0 || m < 0 || n <= 0) + return; #if defined(__AVX512F__) long long dim_0 = m; @@ -99,7 +71,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) { __m512 lut = DATA_TYPE == 1 ? 
set_fp4_lut() : set_nf4_lut(); constexpr auto k_step = VEC_LEN / 2; // 8 - #pragma omp parallel for +#pragma omp parallel for for (int block_idx = 0; block_idx < dim_0; ++block_idx) { for (int k = 0; k < input_dim_1; k += k_step) { // Load 64 bits of nf4 data and a single scale data @@ -112,10 +84,10 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, uint64_t high = 0; uint64_t low = 0; for (int i = 0; i < 4; ++i) { - low |= ((packed >> (2*i * 4)) & 0xf) << ((2*i+1) * 8); - low |= ((packed >> ((2*i+1) * 4)) & 0xf) << (2*i * 8); - high |= ((packed >> (2*i * 4 + 32)) & 0xf) << ((2*i+1) * 8); - high |= ((packed >> ((2*i+1) * 4 + 32)) & 0xf) << (2*i * 8); + low |= ((packed >> (2 * i * 4)) & 0xf) << ((2 * i + 1) * 8); + low |= ((packed >> ((2 * i + 1) * 4)) & 0xf) << (2 * i * 8); + high |= ((packed >> (2 * i * 4 + 32)) & 0xf) << ((2 * i + 1) * 8); + high |= ((packed >> ((2 * i + 1) * 4 + 32)) & 0xf) << (2 * i * 8); } __m128i packed_128 = _mm_set_epi64x(high, low); __m512i vint32 = _mm512_cvtepu8_epi32(packed_128); @@ -126,13 +98,11 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, // Store results T* pout = &out[block_idx * dim_1 + k * 2]; if constexpr (std::is_same()) { - _mm512_storeu_ps(pout, vout); + _mm512_storeu_ps(pout, vout); } else if constexpr (std::is_same()) { - _mm256_storeu_si256( - (__m256i*)pout, cvt_fp32_to_bf16(vout)); + _mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_bf16(vout)); } else if constexpr (std::is_same()) { - _mm256_storeu_si256( - (__m256i*)pout, cvt_fp32_to_fp16(vout)); + _mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_fp16(vout)); } } } @@ -141,7 +111,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, #endif // Scalar fallback branch long long total = m * n; - #pragma omp parallel for +#pragma omp parallel for for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); float scale = absmax[block_idx / blocksize]; @@ -150,11 +120,9 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, unsigned char byte = A[byte_index]; // High nibble first (matches previous code logic) - float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) - : dDequantizeNF4(byte >> 4)) * scale; + float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) : dDequantizeNF4(byte >> 4)) * scale; // Low nibble second - float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) - : dDequantizeNF4(byte & 0x0F)) * scale; + float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) : dDequantizeNF4(byte & 0x0F)) * scale; if constexpr (std::is_same::value) { out[block_idx + i] = float_to_bf16(v0); @@ -177,20 +145,17 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, } } - template -void dequantizeBlockwise8bitCpu(float* code, - unsigned char* A, - const float* absmax, - T* out, - long long blocksize, - long long n) { - if (blocksize <= 0 || n <= 0) return; - // 8-bit path - #pragma omp parallel for +void dequantizeBlockwise8bitCpu( + float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n +) { + if (blocksize <= 0 || n <= 0) + return; +// 8-bit path +#pragma omp parallel for for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = (n - block_idx >= blocksize ? 
blocksize : n - block_idx); - long long block_end = block_idx + valid_items; + long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; for (long long i = block_idx; i < block_end; ++i) { float v = code[A[i]] * scale; @@ -205,7 +170,6 @@ void dequantizeBlockwise8bitCpu(float* code, } } - void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below @@ -253,171 +217,137 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long } } - #if true // || defined(__AVX512F__) && defined(__AVX512BF16__) #define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) -template -struct tinygemm_kernel_nn { - static inline void apply( - const scalar_t*, - const unsigned char*, - scalar_t*, - const scalar_t*, - int64_t, int, int64_t, int64_t, int64_t, int64_t, int64_t) { - static_assert(sizeof(scalar_t) == 0, - "tinygemm_kernel_nn primary template should never be instantiated"); - } +template struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t*, const unsigned char*, scalar_t*, const scalar_t*, int64_t, int, int64_t, int64_t, int64_t, + int64_t, int64_t + ) { + static_assert(sizeof(scalar_t) == 0, "tinygemm_kernel_nn primary template should never be instantiated"); + } }; -template -struct tinygemm_kernel_nn { - static inline void apply( - const bf16_t* __restrict__ A, - const unsigned char* __restrict__ B, - bf16_t* __restrict__ C, - const bf16_t* __restrict__ Bs, - int64_t K, - int group_size, - int64_t lda, - int64_t ldb, - int64_t ldc, - int64_t strideBz, - int64_t strideBs) { - static_assert(BLOCK_N % 32 == 0); - constexpr int ROWS = BLOCK_M; // 32 - constexpr int COLS = BLOCK_N / 16; // 2 - - // prefetch distance - constexpr int PREFETCH_SIZE_K = 16 * 4; - - __m512bh va; - __m512bh vb[COLS]; - __m512 vc[ROWS * COLS]; - __m512 vc_master[ROWS * COLS]; - - __m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit - __m256i fifteen = _mm256_set1_epi8(15); - __m512i lut = DATA_TYPE == 1 ? _mm512_set_epi16( - 0xBF80, 0x3F80, 0x3F39, 0x3F10, 0x3EE1, 0x3EAD, 0x3E7C, 0x3E24, 0x3DA2, 0x0000, 0xBDBA, 0xBE3D, 0xBE91, 0xBECA, 0xBF06, 0xBF32, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 - ) : _mm512_set_epi16( - 0xBF80, 0x3F80, 0x3F39, 0x3F10, 0x3EE1, 0x3EAD, 0x3E7C, 0x3E24, 0x3DA2, 0x0000, 0xBDBA, 0xBE3D, 0xBE91, 0xBECA, 0xBF06, 0xBF32, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 - ); - __m512i bf16_lut = _mm512_set_epi16( - // 0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 - 0x0000, 0x4170, 0x4160, 0x4150, 0x4140, 0x4130, 0x4120, 0x4110, 0x4100, 0x40E0, 0x40C0, 0x40A0, 0x4080, 0x4040, 0x4000, 0x3F80, - // 16 .. 
31 - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 - ); +template struct tinygemm_kernel_nn { + static inline void apply( + const bf16_t* __restrict__ A, const unsigned char* __restrict__ B, bf16_t* __restrict__ C, + const bf16_t* __restrict__ Bs, int64_t K, int group_size, int64_t lda, int64_t ldb, int64_t ldc, + int64_t strideBz, int64_t strideBs + ) { + static_assert(BLOCK_N % 32 == 0); + constexpr int ROWS = BLOCK_M; // 32 + constexpr int COLS = BLOCK_N / 16; // 2 + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 16 * 4; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vc_master[ROWS * COLS]; + + __m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit + __m256i fifteen = _mm256_set1_epi8(15); + __m512i lut = DATA_TYPE == 1 + ? _mm512_set_epi16( + 0x0000, -0x4180, -0x41D5, -0x4100, -0x4155, -0x4080, -0x40D5, -0x4455, 0x0000, 0x3E80, + 0x3E2B, 0x3F00, 0x3EAB, 0x3F80, 0x3F2B, 0x3BAB, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 + ) + : _mm512_set_epi16( + 0x0000, 0x3F80, 0x3F39, 0x3F10, 0x3EE2, 0x3EAD, 0x3E7C, 0x3E25, 0x3DA3, 0x0000, -0x4246, + -0x41C3, -0x416E, -0x4136, -0x40FA, -0x40CE, -0x4080, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 + ); + __m512 scales[COLS]; + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const int64_t gs2 = group_size >> 1; // 64 / 2 = 32 + const float* a_ptr = reinterpret_cast(A); + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + vc_master[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + auto pre_compute = [&](auto i, int64_t kgs) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = _mm512_set1_ps(0.f); // reset accumulator + + // load scales + if constexpr (row == 0 && col % 2 == 0) { + // Bs layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=bf16 + __m512i tmp = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + col * 16)); + scales[col] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 0)); + scales[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 1)); + } + }; + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; - __m512 scales[COLS]; - const int64_t K2 = K >> 1; - const int64_t lda2 = lda >> 1; - const int64_t ldb2 = ldb; // ldb * 2 >> 1; - const int64_t gs2 = group_size >> 1; // 64 / 2 = 32 - const float* a_ptr = reinterpret_cast(A); - - auto loadc = [&](auto i) { - constexpr int col = i % COLS; - vc_master[i] = _mm512_set1_ps(0.f); - }; - Unroll{}(loadc); - - auto pre_compute = [&](auto i, int64_t kgs) { - constexpr int row = i / COLS; - constexpr int col = i % COLS; - vc[i] = _mm512_set1_ps(0.f); // reset accumulator - - // load scales - if constexpr (row == 0 && col % 2 == 0) { - // Bs layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=bf16 - __m512i tmp = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + col * 16)); - scales[col] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 0)); - scales[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 1)); - } - }; - auto compute = [&](auto i, int64_t k) { - constexpr int row = i / COLS; - constexpr int col = i % COLS; - - if constexpr (col == 0) { - va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); - } - if constexpr (row == 0 && col % 2 == 0) { - __m256i 
vb_u4 = _mm256_loadu_si256(reinterpret_cast(B + k * ldb + col * 16)); - - // deinterleave and lookup to BF16 - __m256i vb_i8_lo = vb_u4 & mask; - __m256i vb_i8_hi = _mm256_srli_epi16(vb_u4, 4) & mask; - // vb_i8_lo = _mm256_add_epi8(vb_i8_lo, fifteen); - // vb_i8_hi = _mm256_add_epi8(vb_i8_hi, fifteen); - vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), lut); - vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), lut); - - if constexpr (PREFETCH_SIZE_K > 0) { - _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0 && col % 2 == 0) { + __m256i vb_u4 = _mm256_loadu_si256(reinterpret_cast(B + k * ldb + col * 16)); + + // deinterleave and lookup to BF16 + __m256i vb_i8_lo = vb_u4 & mask; + __m256i vb_i8_hi = _mm256_srli_epi16(vb_u4, 4) & mask; + vb_i8_lo = _mm256_add_epi8(vb_i8_lo, fifteen); + vb_i8_hi = _mm256_add_epi8(vb_i8_hi, fifteen); + vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), lut); + vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), lut); + + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + auto post_compute = [&](auto i, int64_t kgs) { + vc_master[i] = _mm512_fmadd_ps(vc[i], scales[i % COLS], vc_master[i]); + }; + for (int64_t k = 0; k < K2; k += gs2) { + Unroll{}(pre_compute, k / gs2); + for (int64_t k_offset = 0; k_offset < gs2; ++k_offset) { + Unroll{}(compute, k + k_offset); + } + Unroll{}(post_compute, k / gs2); } - } - vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); - }; - auto post_compute = [&](auto i, int64_t kgs) { - vc_master[i] = _mm512_fmadd_ps(vc[i], scales[i % COLS], vc_master[i]); - }; - for (int64_t k = 0; k < K2; k += gs2) { - Unroll{}(pre_compute, k / gs2); - for (int64_t k_offset = 0; k_offset < gs2; ++k_offset) { - Unroll{}(compute, k + k_offset); - } - Unroll{}(post_compute, k / gs2); - } - auto storec = [&](auto i) { - constexpr int row = i / COLS; - constexpr int col = i % COLS; - if constexpr (col % 2 == 0) { - _mm512_storeu_si512( - reinterpret_cast<__m512i*>(C + row * ldc + col * 16), - (__m512i)(_mm512_cvtne2ps_pbh(vc_master[i + 1], vc_master[i]))); - } - }; - Unroll{}(storec); - } + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>(C + row * ldc + col * 16), + (__m512i)(_mm512_cvtne2ps_pbh(vc_master[i + 1], vc_master[i])) + ); + } + }; + Unroll{}(storec); + } }; -#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE, DATA_TYPE) \ - tinygemm_kernel_nn::apply( \ - A + mb_start * lda, \ - B + nb_start, \ - C + mb_start * ldc + nb_start, \ - Bs + nb_start, \ - K, \ - group_size, \ - lda, \ - ldb, \ - ldc, \ - strideBz, \ - strideBs); +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE, DATA_TYPE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start, C + mb_start * ldc + nb_start, Bs + nb_start, K, group_size, lda, ldb, ldc, \ + strideBz, strideBs \ + ); template void tinygemm_kernel( - const scalar_t* __restrict__ A, - const unsigned char* __restrict__ B, - scalar_t* __restrict__ C, - const scalar_t* __restrict__ Bs, - scalar_t* __restrict__ Btmp, - float* __restrict__ Ctmp, - int64_t M, - int64_t N, - int64_t K, - int group_size, - int64_t lda, - int64_t 
ldb, - int64_t ldc, - int64_t strideBz, - int64_t strideBs) { + const scalar_t* __restrict__ A, const unsigned char* __restrict__ B, scalar_t* __restrict__ C, + const scalar_t* __restrict__ Bs, scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, int64_t M, int64_t N, + int64_t K, int group_size, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBz, int64_t strideBs +) { constexpr int64_t BLOCK_M = 4; constexpr int64_t BLOCK_N = 64; const int64_t MB = div_up(M, BLOCK_M); @@ -430,60 +360,53 @@ void tinygemm_kernel( int64_t nb_size = std::min(BLOCK_N, N - nb_start); switch (mb_size << 4 | nb_size >> 4) { - // mb_size = 1 - case 0x12: + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32, DATA_TYPE); break; - case 0x14: + case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64, DATA_TYPE); break; - // mb_size = 2 - case 0x22: + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32, DATA_TYPE); break; - case 0x24: + case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64, DATA_TYPE); break; - // mb_size = 3 - case 0x32: + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32, DATA_TYPE); break; - case 0x34: + case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64, DATA_TYPE); break; - // mb_size = 4 - case 0x42: + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32, DATA_TYPE); break; - case 0x44: + case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64, DATA_TYPE); break; - default: { - std::fprintf(stderr, - "[bitsandbytes] Unexpected block size %lldx%lld\n", - (long long)mb_size, - (long long)nb_size); - std::abort(); // or return; if you prefer silent exit - } + default: { + std::fprintf( + stderr, "[bitsandbytes] Unexpected block size %lldx%lld\n", (long long)mb_size, (long long)nb_size + ); + std::abort(); // or return; if you prefer silent exit + } } } } } template -void gemv_4bit_inference(int64_t M, - int64_t N, - int64_t K, - const T* __restrict__ x, - const unsigned char* __restrict__ w, - const T* __restrict__ absmax, - T* __restrict__ out, - int64_t blocksize, - int64_t x_stride, - int64_t out_stride) { +void gemv_4bit_inference( + int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, + const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +) { constexpr int64_t BLOCK_M = block_size_m(); // 32 constexpr int64_t BLOCK_N = block_size_n(); // 32 - const int64_t MB = div_up(M, BLOCK_M); // (x + y -1)/ y, res = 1 when M <= 32 + const int64_t MB = div_up(M, BLOCK_M); // (x + y -1)/ y, res = 1 when M <= 32 const int64_t NB = div_up(N, BLOCK_N); // TODO: enable brgemm in the future. 
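    // [editor's note] Worked example, illustrative only and not part of the
    // patch: the `mb_size << 4 | nb_size >> 4` dispatch key in tinygemm_kernel
    // above encodes the tile shape in one byte. A full 4 x 64 tile gives
    // (4 << 4) | (64 >> 4) = 0x40 | 0x04 = 0x44, selecting
    // LAUNCH_TINYGEMM_KERNEL_NN(4, 64, DATA_TYPE), while a ragged 1 x 32 edge
    // tile gives (1 << 4) | (32 >> 4) = 0x12.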
// const bool use_brgemm = M > 4; @@ -504,7 +427,7 @@ void gemv_4bit_inference(int64_t M, int64_t nb_size = std::min(N - nb_start, BLOCK_N); tinygemm_kernel( /* A */ x + mb_start * x_stride, - /* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2 + /* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2 /* C */ out + mb_start * out_stride + nb_start, /* Bs */ absmax + nb_start, /* Btmp */ Btmp_inner, @@ -517,7 +440,8 @@ void gemv_4bit_inference(int64_t M, /* ldb */ nb_size, /* ldc */ out_stride, /* sBz */ N, - /* sBs */ N); + /* sBs */ N + ); } } } @@ -528,43 +452,59 @@ void gemv_4bit_inference(int64_t M, } #endif - //============================================================== // TEMPLATE DEFINITIONS //============================================================== template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n +); template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n +); template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n +); // template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float* __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); +// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float* +// __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t 
out_stride); // template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float* __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); +// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float* +// __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); // // template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float* __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); +// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float* +// __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); // template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float* __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); +// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float* +// __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); template void gemv_4bit_inference( - int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +); template void gemv_4bit_inference( - int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); \ No newline at end of file + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +); From 9277d24d51f30ff1fafb56c39ca782de632a88f3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 13 Nov 2025 10:02:10 +0000 Subject: [PATCH 56/78] enable avx512 check Signed-off-by: jiqing-feng --- bitsandbytes/functional.py | 10 +++++++++- bitsandbytes/nn/modules.py | 7 ++----- csrc/cpu_ops.h | 35 +++++++++++++++++++++++++++++++++++ csrc/pythonInterface.cpp | 37 ++++--------------------------------- 4 files changed, 50 insertions(+), 39 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 311e60f26..b18586c10 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import HIP_ENVIRONMENT, lib +from .cextension import HIP_ENVIRONMENT, lib, ErrorHandlerMockBNBNativeLibrary name2qmap = {} @@ -2243,5 +2243,13 @@ def convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState quant_state.dtype = torch.bfloat16 return final_qweight, quant_state +def has_avx512bf16(): + if not isinstance(lib, 
ErrorHandlerMockBNBNativeLibrary) + and hasattr(lib, "has_avx512bf16_cpu") + and lib.has_avx512bf16_cpu(): + return True + + return False + C = 127.0 diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 4683e3e7f..39cc741b1 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,12 +12,10 @@ import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT -from bitsandbytes.functional import QuantState, convert_weight_packed_for_cpu +from bitsandbytes.functional import QuantState, convert_weight_packed_for_cpu, has_avx512bf16 from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer -from ..cextension import ErrorHandlerMockBNBNativeLibrary, lib - T = TypeVar("T", bound="torch.nn.Module") @@ -521,8 +519,7 @@ def forward(self, x: torch.Tensor): if ( not self.enable_optimized_cpu and x.device.type == "cpu" - and not isinstance(lib, ErrorHandlerMockBNBNativeLibrary) - and hasattr(lib, "gemv_4bit_inference_cpu_nf4_bf16") + and has_avx512bf16() and not self.training and x.requires_grad == False ): diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 7178d32a8..60b48eea4 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -293,4 +293,39 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, l void gemv_4bit_inference(int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); #endif +#if defined(__AVX512F__) +#include + +#ifdef _MSC_VER +#include + +static inline bool has_avx512f() { + static bool v = [] { + int info[4]; + __cpuidex(info, 7, 0); + return (info[1] & (1 << 16)) != 0; // EBX bit16 AVX512F + }(); + return v; +} + +static inline bool has_avx512bf16() { + static bool v = [] { + int info[4]; + __cpuidex(info, 7, 1); + return (info[0] & (1 << 5)) != 0; // EAX bit5 AVX512_BF16 + }(); + return v; +} +#else +bool has_avx512f() { + static const bool supported_avx512f = __builtin_cpu_supports("avx512f"); + return supported_avx512f; +} + +bool has_avx512bf16() { + static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16"); + return supported_avx512bf16; +} +#endif + #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 616db6e64..3b23a2239 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -891,22 +891,6 @@ void cdequantize_blockwise_cpu_nf4_fp16( } #if defined(__AVX512F__) && defined(__AVX512BF16__) -// void gemv_4bit_inference_cpu_fp4_fp32( -// int64_t M, int64_t N, int64_t K, -// const float* __restrict__ x, const unsigned char* __restrict__ w, -// const float* __restrict__ absmax, float* __restrict__ out, -// int64_t blocksize, int64_t x_stride, int64_t out_stride -// ) { -// gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); -// } -// void gemv_4bit_inference_cpu_fp4_fp16( -// int64_t M, int64_t N, int64_t K, -// const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, -// const fp16_t* __restrict__ absmax, fp16_t* __restrict__ out, -// int64_t blocksize, int64_t x_stride, int64_t out_stride -// ) { -// gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); -// } void gemv_4bit_inference_cpu_fp4_bf16( int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, @@ -915,23 +899,6 @@ void 
gemv_4bit_inference_cpu_fp4_bf16( ) { gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); } - -// void gemv_4bit_inference_cpu_nf4_fp32( -// int64_t M, int64_t N, int64_t K, -// const float* __restrict__ x, const unsigned char* __restrict__ w, -// const float* __restrict__ absmax, float* __restrict__ out, -// int64_t blocksize, int64_t x_stride, int64_t out_stride -// ) { -// gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); -// } -// void gemv_4bit_inference_cpu_nf4_fp16( -// int64_t M, int64_t N, int64_t K, -// const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, -// const fp16_t* __restrict__ absmax, fp16_t* __restrict__ out, -// int64_t blocksize, int64_t x_stride, int64_t out_stride -// ) { -// gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); -// } void gemv_4bit_inference_cpu_nf4_bf16( int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, @@ -941,4 +908,8 @@ void gemv_4bit_inference_cpu_nf4_bf16( gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); } #endif +#if defined(__AVX512F__) +bool has_avx512f_cpu() return has_avx512f() +bool has_avx512bf16_cpu() return has_avx512bf16() +#endif } From 4fb315bc768bb32d3a1e580522ecdd8f40f78b79 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 13 Nov 2025 10:08:18 +0000 Subject: [PATCH 57/78] fix check Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 8 ++------ bitsandbytes/functional.py | 10 ++++------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 6c09c0aca..ad634b146 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -4,7 +4,7 @@ import torch -from bitsandbytes.functional import get_ptr +from bitsandbytes.functional import get_ptr, has_avx512bf16 from ..._ops import register_kernel from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib @@ -203,8 +203,7 @@ def _( return out - if hasattr(lib, "gemv_4bit_inference_cpu_nf4_bf16"): - + if has_avx512bf16(): @register_kernel("bitsandbytes::gemv_4bit", "cpu") def _( A: torch.Tensor, @@ -243,9 +242,6 @@ def _( ct.c_int64(out_strideM), ) elif quant_type == "nf4": - # print(A) - # print(B) - # print(absmax) lib.gemv_4bit_inference_cpu_nf4_bf16( ct.c_int64(M), ct.c_int64(N), diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b18586c10..8d2c9a7a2 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import HIP_ENVIRONMENT, lib, ErrorHandlerMockBNBNativeLibrary +from .cextension import HIP_ENVIRONMENT, lib name2qmap = {} @@ -2244,12 +2244,10 @@ def convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState return final_qweight, quant_state def has_avx512bf16(): - if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary) - and hasattr(lib, "has_avx512bf16_cpu") - and lib.has_avx512bf16_cpu(): + if hasattr(lib, "has_avx512bf16_cpu") and lib.has_avx512bf16_cpu(): return True - - return False + else: + return False C = 127.0 From 81f19844420a8e004947f76878166b8b88adf630 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 13 Nov 2025 10:13:35 +0000 Subject: [PATCH 58/78] fix endif Signed-off-by: jiqing-feng --- csrc/cpu_ops.h | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h 
index 60b48eea4..677b90954 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -327,5 +327,6 @@ bool has_avx512bf16() { return supported_avx512bf16; } #endif +#endif #endif From 0f78bada2df047b5bca97669b1813e0ec6d519b6 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 13 Nov 2025 10:14:59 +0000 Subject: [PATCH 59/78] fix format Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 1 + bitsandbytes/backends/utils.py | 1 - bitsandbytes/functional.py | 1 + bitsandbytes/utils.py | 2 +- csrc/cpu_ops.h | 150 ++++++++++++++++--------------- csrc/pythonInterface.cpp | 24 ++--- 6 files changed, 93 insertions(+), 86 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index ad634b146..1943c5357 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -204,6 +204,7 @@ def _( return out if has_avx512bf16(): + @register_kernel("bitsandbytes::gemv_4bit", "cpu") def _( A: torch.Tensor, diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py index 1d6abd5c1..ec96a440c 100644 --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -1,7 +1,6 @@ import subprocess from packaging import version -from collections.abc import Sequence import torch try: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 8d2c9a7a2..54369770d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2243,6 +2243,7 @@ def convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState quant_state.dtype = torch.bfloat16 return final_qweight, quant_state + def has_avx512bf16(): if hasattr(lib, "has_avx512bf16_cpu") and lib.has_avx512bf16_cpu(): return True diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index dcb7798af..98ccd7da6 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -3,7 +3,7 @@ import subprocess import torch -from collections.abc import Sequence + def outlier_hook(module, input): assert isinstance(module, torch.nn.Linear) diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 677b90954..91e7c1ebc 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -1,15 +1,15 @@ #ifndef BITSANDBYTES_CPU_OPS_H #define BITSANDBYTES_CPU_OPS_H +#include +#include #include #include #include -#include -#include #include #if defined(_OPENMP) - #include +#include #endif // amx-bf16 @@ -17,17 +17,17 @@ #define TILE_N 16 #define TILE_K 32 // work around compiler internal error -#define BLOCK_K 128 // 4 * TILE_K +#define BLOCK_K 128 // 4 * TILE_K // block size for AMX gemm constexpr int block_size_m() { return 2 * TILE_M; } + constexpr int block_size_n() { return 2 * TILE_N; } -template -inline int get_cache_blocks(int chunk_size) { - // L2 2MB and ratio of 50% - const int L2_size = 2048 * 1024 >> 1; - return std::max(1, int(L2_size / (chunk_size * sizeof(T)))); +template inline int get_cache_blocks(int chunk_size) { + // L2 2MB and ratio of 50% + const int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (chunk_size * sizeof(T)))); } // forced unroll for perf critical path @@ -37,25 +37,22 @@ inline int get_cache_blocks(int chunk_size) { #define ALWAYS_INLINE inline #endif -template -struct Unroll { - template - ALWAYS_INLINE void operator()(const Func& f, Args... args) const { - Unroll{}(f, args...); - f(std::integral_constant{}, args...); - } +template struct Unroll { + template ALWAYS_INLINE void operator()(const Func& f, Args... 
args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } }; -template <> -struct Unroll<1> { - template - ALWAYS_INLINE void operator()(const Func& f, Args... args) const { - f(std::integral_constant{}, args...); - } +template <> struct Unroll<1> { + template ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } }; -template ::value, int>::type = 0> -inline T div_up(T x, T y) { return (x + y - 1) / y; } +template ::value, int>::type = 0> inline T div_up(T x, T y) { + return (x + y - 1) / y; +} inline int get_max_threads() { #if defined(_OPENMP) @@ -68,60 +65,59 @@ inline int get_max_threads() { inline int adjust_num_threads(int m) { int actual_nth = get_max_threads(); - if (m == 1) return actual_nth; + if (m == 1) + return actual_nth; return std::max(1, (actual_nth >> 1) * 2); } -template -inline void parallel_2d(int m, int n, const func_t& f) { - // make sure we have even num_threads - int nth = adjust_num_threads(m); - - // [NOTE] thread blocking: - // - // 1) prefer square block per thread - // 2) use even number of CPU cores - // 3) use all `num_threads` cores - // - // we have: - // TM * TN = T - // BM / TM = BN / TN - // then: - // TM = ((BM / BN) * T) ^ 0.5 - // - float r = float(m) / n; - int nth_m = std::ceil(std::sqrt(r * nth)); - int nth_n = 1; - for (; nth_m > 0; --nth_m) { - nth_n = nth / nth_m; - if (nth_m * nth_n == nth) { - break; +template inline void parallel_2d(int m, int n, const func_t& f) { + // make sure we have even num_threads + int nth = adjust_num_threads(m); + + // [NOTE] thread blocking: + // + // 1) prefer square block per thread + // 2) use even number of CPU cores + // 3) use all `num_threads` cores + // + // we have: + // TM * TN = T + // BM / TM = BN / TN + // then: + // TM = ((BM / BN) * T) ^ 0.5 + // + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) { + break; + } } - } #if defined(_OPENMP) #pragma omp parallel num_threads(nth) - { - int ith = omp_get_thread_num(); - int ith_m = ith / nth_n; - int ith_n = ith % nth_n; + { + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; - int thread_block_m = div_up(m, nth_m); - int thread_block_n = div_up(n, nth_n); + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); - int begin_m = ith_m * thread_block_m; - int end_m = std::min(m, begin_m + thread_block_m); - int begin_n = ith_n * thread_block_n; - int end_n = std::min(n, begin_n + thread_block_n); + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); - f(begin_m, end_m, begin_n, end_n); - } + f(begin_m, end_m, begin_n, end_n); + } #else - f(0, m, 0, n); + f(0, m, 0, n); #endif } - void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); typedef enum DataType_t { @@ -155,17 +151,17 @@ static inline fp16_t float_to_fp16(float x) { uint32_t bits; std::memcpy(&bits, &x, 4); uint32_t sign = (bits >> 31) & 0x1; - uint32_t exp = (bits >> 23) & 0xFF; + uint32_t exp = (bits >> 23) & 0xFF; uint32_t mant = bits & 0x7FFFFF; uint16_t h; - if (exp == 0xFF) { // Inf / NaN + if (exp == 0xFF) { // Inf / NaN uint16_t mant16 = mant ? 
0x200 : 0; // quiet NaN: set MSB of mantissa h = (sign << 15) | (0x1F << 10) | mant16; - } else if (exp > 0x70 + 0x1E) { // overflow: exp_f -127 +15 > 30 (exp_f > 142) + } else if (exp > 0x70 + 0x1E) { // overflow: exp_f -127 +15 > 30 (exp_f > 142) h = (sign << 15) | (0x1F << 10); // Inf - } else if (exp < 0x71) { // subnormal or zero (exp_f < 113) - if (exp < 0x67) { // too small -> zero (exp_f < 103) + } else if (exp < 0x71) { // subnormal or zero (exp_f < 113) + if (exp < 0x67) { // too small -> zero (exp_f < 103) h = (sign << 15); } else { // subnormal: implicit leading 1 @@ -281,16 +277,22 @@ inline float dDequantizeNF4(unsigned char val) { return -1.0f; //*0000 } - template -void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n); +void dequantizeBlockwise8bitCpu( + float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n +); template -void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); +void dequantizeBlockwise4bitCpu( + unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n +); #if defined(__AVX512F__) && defined(__AVX512BF16__) - template - void gemv_4bit_inference(int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); +template +void gemv_4bit_inference( + int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, + const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +); #endif #if defined(__AVX512F__) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 3b23a2239..c07666cdf 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -847,11 +847,13 @@ void cdequantize_blockwise_cpu_fp32( ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_cpu_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_cpu_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { @@ -863,11 +865,13 @@ void cdequantize_blockwise_cpu_fp4_fp32( ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_fp4_bf16( unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_fp4_fp16( unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { @@ -879,11 +883,13 @@ void cdequantize_blockwise_cpu_nf4_fp32( ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_nf4_bf16( unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_nf4_fp16( unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { @@ -892,24 +898,22 @@ void cdequantize_blockwise_cpu_nf4_fp16( #if defined(__AVX512F__) && defined(__AVX512BF16__) void gemv_4bit_inference_cpu_fp4_bf16( - 
int64_t M, int64_t N, int64_t K, - const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, - const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, - int64_t blocksize, int64_t x_stride, int64_t out_stride + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride ) { gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); } + void gemv_4bit_inference_cpu_nf4_bf16( - int64_t M, int64_t N, int64_t K, - const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, - const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, - int64_t blocksize, int64_t x_stride, int64_t out_stride + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride ) { gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); } #endif #if defined(__AVX512F__) -bool has_avx512f_cpu() return has_avx512f() -bool has_avx512bf16_cpu() return has_avx512bf16() +bool has_avx512f_cpu() { return has_avx512f() } + +bool has_avx512bf16_cpu() { return has_avx512bf16() } #endif } From fcb84565397cd89fa616b2488f496bcdbd8c449c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 13 Nov 2025 10:16:01 +0000 Subject: [PATCH 60/78] fix format Signed-off-by: jiqing-feng --- csrc/pythonInterface.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index c07666cdf..a9515325c 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -912,8 +912,8 @@ void gemv_4bit_inference_cpu_nf4_bf16( } #endif #if defined(__AVX512F__) -bool has_avx512f_cpu() { return has_avx512f() } +bool has_avx512f_cpu() { return has_avx512f(); } -bool has_avx512bf16_cpu() { return has_avx512bf16() } +bool has_avx512bf16_cpu() { return has_avx512bf16(); } #endif } From c5e18945bb32f83453782394816a22cb1db7694c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 13 Nov 2025 10:20:01 +0000 Subject: [PATCH 61/78] fix def Signed-off-by: jiqing-feng --- csrc/cpu_ops.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 91e7c1ebc..503ec556b 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -319,12 +319,12 @@ static inline bool has_avx512bf16() { return v; } #else -bool has_avx512f() { +static inline bool has_avx512f() { static const bool supported_avx512f = __builtin_cpu_supports("avx512f"); return supported_avx512f; } -bool has_avx512bf16() { +static inline bool has_avx512bf16() { static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16"); return supported_avx512bf16; } From df1d669a82d14fdfa57e329637391ff15e8f69da Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 14 Nov 2025 09:42:08 +0000 Subject: [PATCH 62/78] fix position Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 16 ++-------------- csrc/cpu_ops.h | 16 ++++++++-------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 31a5deca0..bc0a664ba 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -1,5 +1,4 @@ #include -#include #include #include @@ -529,19 +528,7 @@ template void dequantizeBlockwise4bitCpu( unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n 
); -// template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float* -// __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); -// template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float* -// __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); -// -// template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float* -// __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); -// template void gemv_4bit_inference( -// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float* -// __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); +#if defined(__AVX512F__) && defined(__AVX512BF16__) template void gemv_4bit_inference( int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride @@ -550,3 +537,4 @@ template void gemv_4bit_inference( int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride ); +#endif diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 5159af858..1e38493f3 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -283,14 +283,6 @@ void dequantizeBlockwise4bitCpu( unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n ); -#if defined(__AVX512F__) && defined(__AVX512BF16__) -template -void gemv_4bit_inference( - int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, - const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride -); -#endif - #if defined(__AVX512F__) #include @@ -327,4 +319,12 @@ static inline bool has_avx512bf16() { #endif #endif +#if defined(__AVX512F__) && defined(__AVX512BF16__) +template +void gemv_4bit_inference( + int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, + const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +); +#endif + #endif From bb3ac8da08d2a22096e33de02f9f2bcaa9374ac4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 14 Nov 2025 09:42:33 +0000 Subject: [PATCH 63/78] fix format Signed-off-by: jiqing-feng --- csrc/cpu_ops.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 1e38493f3..66ff32d04 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -3,11 +3,11 @@ #include #include +#include #include #include #include #include -#include #if defined(_OPENMP) #include From 26b56852bd4ac99ad50992a4d8957dfeb9a68214 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 14 Nov 2025 13:14:04 +0000 Subject: [PATCH 64/78] rm duplicated func Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index bc0a664ba..f569bf681 100644 --- a/csrc/cpu_ops.cpp 
+++ b/csrc/cpu_ops.cpp @@ -14,38 +14,6 @@ using namespace BinSearch; #if defined(__AVX512F__) #include -#ifdef _MSC_VER -#include - -static inline bool has_avx512f() { - static bool v = [] { - int info[4]; - __cpuidex(info, 7, 0); - return (info[1] & (1 << 16)) != 0; // EBX bit16 AVX512F - }(); - return v; -} - -static inline bool has_avx512bf16() { - static bool v = [] { - int info[4]; - __cpuidex(info, 7, 1); - return (info[0] & (1 << 5)) != 0; // EAX bit5 AVX512_BF16 - }(); - return v; -} -#else -bool has_avx512f() { - static const bool supported_avx512f = __builtin_cpu_supports("avx512f"); - return supported_avx512f; -} - -bool has_avx512bf16() { - static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16"); - return supported_avx512bf16; -} -#endif - inline __m256i cvt_fp32_to_fp16(const __m512 src) { return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); } From 580010cc0a391534f06cab77c009c73b6174f8df Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 17 Nov 2025 09:35:50 +0000 Subject: [PATCH 65/78] rm useless code comments Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 37 -------------------------------- 1 file changed, 37 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 4a47b7d75..cece524da 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -275,40 +275,3 @@ def _( out = out.to(dtype) return out - -# def unpack_weight_packed_for_cpu(packed_qweight: torch.Tensor, block_n: int = 32): -# """ -# Inverse of convert_weight_packed_for_cpu. -# packed_qweight: (N, K//2) uint8, each byte = (high<<4)|low, both 4-bit values in 0..15 -# returns: qweight_final (N, K) uint8 with original 4-bit values (0..15) -# """ -# assert packed_qweight.dtype == torch.uint8 -# assert packed_qweight.dim() == 2 -# N, K_half = packed_qweight.shape -# assert N % block_n == 0 -# BIT_COUNT = block_n # 32 -# # reshape to rows of 32 packed bytes -# qw = packed_qweight.reshape(-1, BIT_COUNT) # [(N//block_n)*K_half, 32] -# low = (qw & 0x0F) -# high = (qw >> 4) & 0x0F -# # restore 64 nibbles (low first then high, matching original pack order) -# restored = torch.cat([low, high], dim=1) # [..., 64] -# # reshape back (inverse of flatten) -# restored = restored.reshape(N // block_n, K_half, block_n, 2) # [N/block_n, K//2, block_n, 2] -# # inverse transpose -# restored = restored.transpose(-3, -2) # [N/block_n, block_n, K//2, 2] -# # final shape -# qweight_final = restored.reshape(N, K_half * 2).to(torch.uint8) -# return qweight_final - - -# _NF4_QUANT_TABLE = torch.tensor([ -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, -# 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0 ], dtype=torch.float32) - -# def fused_matmul(x, packed_weight, scales, group_size): -# unpacked_weight = unpack_weight_packed_for_cpu(packed_weight) -# shape = unpacked_weight.shape -# # original_weight = _INT4_0_TO_15_TABLE[unpacked_weight.reshape(-1).int()].reshape(shape) * scales.T.repeat_interleave(group_size, dim=1) -# original_weight = _NF4_QUANT_TABLE[unpacked_weight.reshape(-1).int()].reshape(shape) * scales.T.repeat_interleave(group_size, dim=1) -# res = torch.matmul(x, original_weight.T.to(x.dtype)) -# return res From 57b89bfa3e5061aacb16ce19b695ae99d1be3f86 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: 
Wed, 19 Nov 2025 15:47:17 +0000 Subject: [PATCH 66/78] fix out shape Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index cece524da..db182a95c 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -236,6 +236,7 @@ def _( if dtype != torch.bfloat16: A = A.to(torch.bfloat16) + final_out_shape = (*A.shape[:-1], shapeB[0]) A = A.reshape(-1, A.shape[-1]) out_shape = (*A.shape[:-1], shapeB[0]) out = torch.empty(out_shape, dtype=A.dtype, device=A.device) @@ -274,4 +275,4 @@ def _( if dtype != torch.bfloat16: out = out.to(dtype) - return out + return out.reshape(final_out_shape) From de5fb9c9c97af7b8e40dea1714d4b76fd51742c8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 20 Nov 2025 09:13:21 +0000 Subject: [PATCH 67/78] fix comments Signed-off-by: jiqing-feng --- CMakeLists.txt | 15 +++------------ bitsandbytes/autograd/_functions.py | 2 +- bitsandbytes/functional.py | 2 +- bitsandbytes/nn/modules.py | 16 ++++++++-------- 4 files changed, 13 insertions(+), 22 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3d9b34332..f88ac2b11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -280,24 +280,15 @@ if (BUILD_CPU) include(CheckCXXCompilerFlag) check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG) check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG) - check_cxx_compiler_flag(-mavx512dq HAS_AVX512DQ) - check_cxx_compiler_flag(-mavx512bw HAS_AVX512BW) - check_cxx_compiler_flag(-mavx512vl HAS_AVX512VL) if (HAS_AVX512F_FLAG) target_compile_options(bitsandbytes PRIVATE -mavx512f) - endif() - if (HAS_AVX512BF16_FLAG) - target_compile_options(bitsandbytes PRIVATE -mavx512bf16) - endif() - if(HAS_AVX512DQ) target_compile_options(bitsandbytes PRIVATE -mavx512dq) - endif() - if(HAS_AVX512BW) target_compile_options(bitsandbytes PRIVATE -mavx512bw) - endif() - if(HAS_AVX512VL) target_compile_options(bitsandbytes PRIVATE -mavx512vl) endif() + if (HAS_AVX512BF16_FLAG) + target_compile_options(bitsandbytes PRIVATE -mavx512bf16) + endif() target_compile_options( bitsandbytes PRIVATE -mprefer-vector-width=256 diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index f127327be..0b89c012a 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -378,7 +378,7 @@ def matmul_4bit( if A.device.type == "cpu": quant_state.dtype = A.dtype - if getattr(quant_state, "enable_optimized_cpu", False): + if getattr(quant_state, "packing_format_for_cpu", False): out = F.gemv_4bit(A, B, out, state=quant_state) if bias is not None: out += bias diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 15ff3a181..509379527 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2103,7 +2103,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): return out -def convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32): +def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32): """ qweight: (K * N / 2) uint8 return: packed_weight diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 296c92b40..ed487323b 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,7 +12,7 @@ import bitsandbytes as bnb from bitsandbytes.cextension import ROCM_WARP_SIZE_64 -from bitsandbytes.functional 
import QuantState, convert_weight_packed_for_cpu, has_avx512bf16 +from bitsandbytes.functional import QuantState, _convert_weight_packed_for_cpu, has_avx512bf16 from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer @@ -479,7 +479,7 @@ def __init__( self.compute_type_is_set = compute_dtype is not None self.quant_state = None self.quant_storage = quant_storage - self.enable_optimized_cpu = False + self.packing_format_for_cpu = False def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -513,19 +513,19 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): destination[prefix + "weight." + k] = v if keep_vars else v.detach() def forward(self, x: torch.Tensor): - quant_state = self.weight.quant_state fix_4bit_weight_quant_state_from_module(self) + quant_state = self.weight.quant_state if ( - not self.enable_optimized_cpu + not self.packing_format_for_cpu and x.device.type == "cpu" and has_avx512bf16() and not self.training and x.requires_grad == False ): - self.weight.data, quant_state = convert_weight_packed_for_cpu(self.weight.data, quant_state) - self.enable_optimized_cpu = True - quant_state.enable_optimized_cpu = True + self.weight.data, quant_state = _convert_weight_packed_for_cpu(self.weight.data, quant_state) + self.packing_format_for_cpu = True + quant_state.packing_format_for_cpu = True # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: @@ -540,7 +540,7 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) - weight = self.weight if getattr(quant_state, "enable_optimized_cpu", False) else self.weight.t() + weight = self.weight if getattr(quant_state, "packing_format_for_cpu", False) else self.weight.t() return bnb.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype) From 6858a90b51de6927eedd00ea549b18cb0e3ce669 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 20 Nov 2025 10:32:47 +0000 Subject: [PATCH 68/78] add reverse format Signed-off-by: jiqing-feng --- bitsandbytes/functional.py | 68 ++++++++++++++++++++++++++++++++++++++ bitsandbytes/nn/modules.py | 21 ++++++++---- 2 files changed, 83 insertions(+), 6 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 509379527..449178504 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2109,6 +2109,10 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat return: packed_weight """ assert qweight.dtype == torch.uint8, "qweight must be uint8" + quant_state.original_dtype = quant_state.dtype + quant_state.original_nested = quant_state.nested + quant_state.original_qshape = qweight.shape + qweight = qweight.reshape(-1) unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=qweight.device) unpacked_w[1::2] = qweight & 0xF @@ -2145,9 +2149,73 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat delattr(quant_state, "state2") quant_state.dtype = torch.bfloat16 + quant_state.packing_format_for_cpu = True return final_qweight, quant_state +def _convert_weight_packed_for_cpu_inverse( + packed_weight: torch.Tensor, + quant_state: QuantState, + block_n: int = 32, +) -> tuple[torch.Tensor, QuantState]: + """ + packed_weight: [N, K/2] uint8, output of `_convert_weight_packed_for_cpu` (final_qweight) + 
quant_state: QuantState that was modified by `_convert_weight_packed_for_cpu` + Returns: + qweight: [*, N, K] uint8, original qweight shape (quant_state.shape) + recovered_state: QuantState with partially restored fields (best-effort inverse) + """ + assert quant_state.packing_format_for_cpu, "only for packing format" + assert packed_weight.dtype == torch.uint8 + assert len(packed_weight.shape) == 2, "packed_weight should be [N, K/2]" + N, K_half = packed_weight.shape + K = K_half * 2 + + # 1) packed [N, K/2] -> [N//BLOCK_N, BLOCK_N, K/2, 2] + BLOCK_N = block_n + BIT_COUNT = 32 # (=32 low + 32 high) + + assert N % BLOCK_N == 0, "N must be divisible by block_n" + assert K % 2 == 0, "K must be even" + + # [N, K/2] -> [-1, 64] (32 low + 32 high) + packed = packed_weight.reshape(-1, BIT_COUNT) # [-1, 64] + # split high/low nibbles + high = (packed >> 4) & 0xF + low = packed & 0xF + # concatenate to [..., 64], first 32 are low, last 32 are high + qw = torch.cat([low, high], dim=-1).to(torch.uint8) # [..., 64] + + # -> [N/BLOCK_N, K/2, BLOCK_N, 2] -> [N, K] + qw = qw.reshape(N // BLOCK_N, K_half, BLOCK_N, 2) # [N/B, K/2, B, 2] + qw = qw.transpose(-3, -2).contiguous() # [N/B, B, K/2, 2] + qw = qw.reshape(N, K) # [N, K] + + qweight = qw # [N, K] + + unpacked_w = qweight.reshape(-1).to(torch.int32) # [K*N] + high4 = (unpacked_w[::2] & 0xF).to(torch.uint8) + low4 = (unpacked_w[1::2] & 0xF).to(torch.uint8) + qweight = (high4 << 4) | low4 # [K*N/2] + + # 2) Best-effort restore of quant_state fields (absmax / dtype / nested flags, etc.) + recovered_state = quant_state + + # quantize absmax + if recovered_state.original_nested: + absmax = recovered_state.absmax.T.reshape(-1).to(recovered_state.original_dtype) + offset = absmax.mean() + qabsmax, state2 = quantize_blockwise(absmax - offset, blocksize=256) + recovered_state.absmax = qabsmax + recovered_state.offset = offset + recovered_state.state2 = state2 + + recovered_state.dtype = recovered_state.original_dtype + recovered_state.packing_format_for_cpu = False + + return qweight.to(torch.uint8).reshape(recovered_state.original_qshape), recovered_state + + def has_avx512bf16(): if hasattr(lib, "has_avx512bf16_cpu") and lib.has_avx512bf16_cpu(): return True diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ed487323b..b613606c2 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,7 +12,12 @@ import bitsandbytes as bnb from bitsandbytes.cextension import ROCM_WARP_SIZE_64 -from bitsandbytes.functional import QuantState, _convert_weight_packed_for_cpu, has_avx512bf16 +from bitsandbytes.functional import ( + QuantState, + _convert_weight_packed_for_cpu, + _convert_weight_packed_for_cpu_inverse, + has_avx512bf16, +) from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer @@ -311,9 +316,13 @@ def cpu(self): return self.to(device="cpu") def cuda(self, device: Optional[int | device | str] = None, non_blocking: bool = False): + if getattr(self.quant_state, "packing_format_for_cpu", False): + self.data, self.quant_state = _convert_weight_packed_for_cpu_inverse(self.data, self.quant_state) return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) def xpu(self, device: Optional[int | device | str] = None, non_blocking: bool = False): + if getattr(self.quant_state, "packing_format_for_cpu", False): + self.data, self.quant_state = _convert_weight_packed_for_cpu_inverse(self.data, self.quant_state) return 
self.to(device="xpu" if device is None else device, non_blocking=non_blocking) @overload @@ -479,7 +488,6 @@ def __init__( self.compute_type_is_set = compute_dtype is not None self.quant_state = None self.quant_storage = quant_storage - self.packing_format_for_cpu = False def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -507,7 +515,10 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): then fill state_dict with components of quant_state """ super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias - + if getattr(self.weight.quant_state, "packing_format_for_cpu", False): + self.weight.data, self.weight.quant_state = _convert_weight_packed_for_cpu_inverse( + self.weight.data, self.weight.quant_state + ) if getattr(self.weight, "quant_state", None) is not None: for k, v in self.weight.quant_state.as_dict(packed=True).items(): destination[prefix + "weight." + k] = v if keep_vars else v.detach() @@ -517,15 +528,13 @@ def forward(self, x: torch.Tensor): quant_state = self.weight.quant_state if ( - not self.packing_format_for_cpu + not getattr(quant_state, "packing_format_for_cpu", False) and x.device.type == "cpu" and has_avx512bf16() and not self.training and x.requires_grad == False ): self.weight.data, quant_state = _convert_weight_packed_for_cpu(self.weight.data, quant_state) - self.packing_format_for_cpu = True - quant_state.packing_format_for_cpu = True # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: From 3b3d609b9d16cf210453c805b1033b72a480888a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 20 Nov 2025 10:51:24 +0000 Subject: [PATCH 69/78] check avx512bf15 Signed-off-by: jiqing-feng --- csrc/cpu_ops.h | 4 ++++ csrc/pythonInterface.cpp | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 66ff32d04..6803b29f9 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -298,6 +298,7 @@ static inline bool has_avx512f() { return v; } +#if defined(__AVX512BF16__) static inline bool has_avx512bf16() { static bool v = [] { int info[4]; @@ -306,18 +307,21 @@ static inline bool has_avx512bf16() { }(); return v; } +#endif #else static inline bool has_avx512f() { static const bool supported_avx512f = __builtin_cpu_supports("avx512f"); return supported_avx512f; } +#if defined(__AVX512BF16__) static inline bool has_avx512bf16() { static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16"); return supported_avx512bf16; } #endif #endif +#endif #if defined(__AVX512F__) && defined(__AVX512BF16__) template diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 5450897a4..07c79fc95 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -909,7 +909,8 @@ void gemv_4bit_inference_cpu_nf4_bf16( #endif #if defined(__AVX512F__) bool has_avx512f_cpu() { return has_avx512f(); } - +#if defined(__AVX512BF16__) bool has_avx512bf16_cpu() { return has_avx512bf16(); } #endif +#endif } From fbb911b69e032c831cf4e9b8dd8a2fb814d907cb Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 20 Nov 2025 12:12:39 +0000 Subject: [PATCH 70/78] fix has_avx512bf16 Signed-off-by: jiqing-feng --- bitsandbytes/functional.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 449178504..4baf7746f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2217,10 
+2217,15 @@ def _convert_weight_packed_for_cpu_inverse( def has_avx512bf16(): - if hasattr(lib, "has_avx512bf16_cpu") and lib.has_avx512bf16_cpu(): - return True - else: - return False + """ + Try calling native lib.has_avx512bf16_cpu(). + Return False explicitly if symbol missing or call fails. + """ + try: + support_avx_bf16 = lib.has_avx512bf16_cpu() + except (AttributeError, RuntimeError, OSError): + support_avx_bf16 = False + return support_avx_bf16 C = 127.0 From 3179b42be6ed0c5ee2e81099abf4e48587da5e6f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 20 Nov 2025 15:44:56 +0000 Subject: [PATCH 71/78] fix tests Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 2 +- bitsandbytes/functional.py | 10 ++++++++-- tests/test_functional.py | 4 ++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index db182a95c..def87045c 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -229,7 +229,7 @@ def _( code: torch.Tensor, blocksize: int, ) -> torch.Tensor: - # Applied from dequantize_4bit + assert B.dtype == torch.uint8, "Only support uint8 qweight" dtype = A.dtype quant_type = "fp4" if code[1] > 0 else "nf4" # cpu fused op only support bf16 for now. diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 4baf7746f..cdbd7c2f7 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2108,7 +2108,9 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat qweight: (K * N / 2) uint8 return: packed_weight """ - assert qweight.dtype == torch.uint8, "qweight must be uint8" + if qweight.dtype != torch.uint8: + quant_state.original_storage_type = qweight.dtype + qweight = qweight.view(torch.uint8) quant_state.original_dtype = quant_state.dtype quant_state.original_nested = quant_state.nested quant_state.original_qshape = qweight.shape @@ -2200,6 +2202,7 @@ def _convert_weight_packed_for_cpu_inverse( # 2) Best-effort restore of quant_state fields (absmax / dtype / nested flags, etc.) 
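    # [editor's note] Illustrative aside, not part of the patch: each packed
    # byte stores two 4-bit code indices, e.g. high4 = 0xA and low4 = 0x3 pack
    # into (0xA << 4) | 0x3 == 0xA3, and unpack as (0xA3 >> 4) & 0xF == 0xA and
    # 0xA3 & 0xF == 0x3; the `(high4 << 4) | low4` repack earlier in this
    # function reverses exactly that nibble split.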
recovered_state = quant_state
+    qweight = qweight.to(torch.uint8).reshape(recovered_state.original_qshape)

     # quantize absmax
     if recovered_state.original_nested:
@@ -2213,7 +2216,10 @@ def _convert_weight_packed_for_cpu_inverse(
     recovered_state.dtype = recovered_state.original_dtype
     recovered_state.packing_format_for_cpu = False

-    return qweight.to(torch.uint8).reshape(recovered_state.original_qshape), recovered_state
+    if getattr(recovered_state, "original_storage_type", None):
+        qweight = qweight.view(recovered_state.original_storage_type)
+
+    return qweight, recovered_state


 def has_avx512bf16():
diff --git a/tests/test_functional.py b/tests/test_functional.py
index d420ff352..55964818c 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1318,6 +1318,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
             quant_storage=quant_storage,
         )
         C3 = torch.matmul(A, B.t())
+        # CPU requires convert weight packed for gemv
+        if device == "cpu" and F.has_avx512bf16():
+            qB, state = F._convert_weight_packed_for_cpu(qB, state)
+            qB = qB.t()
         C2 = F.gemv_4bit(A, qB.t(), state=state)
         A.requires_grad = True
         C1 = bnb.matmul_4bit(A, qB.t(), state)

From 0c88d436d103a0ea55b25e135a2e090ce565eeb0 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 20 Nov 2025 15:58:27 +0000
Subject: [PATCH 72/78] fix absmax shape

Signed-off-by: jiqing-feng
---
 bitsandbytes/functional.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index cdbd7c2f7..f97d27cca 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -2142,14 +2142,16 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
     if absmax.dtype != torch.float32:
         absmax = absmax.float()

-    quant_state.absmax = (
-        absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
-        .T.to(torch.bfloat16)
-        .contiguous()
-    )
+    quant_state.absmax = absmax
     quant_state.nested = False
     delattr(quant_state, "state2")

+    quant_state.absmax = (
+        quant_state.absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
+        .T.to(torch.bfloat16)
+        .contiguous()
+    )
+
     quant_state.dtype = torch.bfloat16
     quant_state.packing_format_for_cpu = True
     return final_qweight, quant_state

From feb8ad22c9d6911c175f7e03f1177363a58f3735 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 20 Nov 2025 16:21:19 +0000
Subject: [PATCH 73/78] fix compile

Signed-off-by: jiqing-feng
---
 bitsandbytes/nn/modules.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index b613606c2..1c9fac799 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -488,6 +488,7 @@ def __init__(
         self.compute_type_is_set = compute_dtype is not None
         self.quant_state = None
         self.quant_storage = quant_storage
+        self.support_avx512bf16_for_cpu = has_avx512bf16()

     def set_compute_type(self, x):
         if x.dtype in [torch.float32, torch.bfloat16]:
@@ -530,7 +531,7 @@ def forward(self, x: torch.Tensor):
         if (
             not getattr(quant_state, "packing_format_for_cpu", False)
             and x.device.type == "cpu"
-            and has_avx512bf16()
+            and self.support_avx512bf16_for_cpu
             and not self.training
             and x.requires_grad == False
         ):
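[Editor's note] The packing format that patches 68-72 above refine is easiest to sanity-check as a nibble-level round trip. Below is a minimal, self-contained Python sketch mirroring the [N, K/2] uint8 layout with block_n = 32 from `_convert_weight_packed_for_cpu` and its inverse; it skips the initial bnb nibble interleave, and `pack_rows`/`unpack_rows` are illustrative names, not bitsandbytes API.

import torch

BLOCK_N = 32  # block_n in the patches

def pack_rows(qw: torch.Tensor) -> torch.Tensor:
    # qw: [N, K] uint8 of 4-bit code indices (0..15) -> packed [N, K//2] uint8
    N, K = qw.shape
    g = qw.reshape(N // BLOCK_N, BLOCK_N, K // 2, 2).transpose(-3, -2)  # [N/B, K/2, B, 2]
    g = g.reshape(-1, 2 * BLOCK_N)  # rows of 64 nibbles: 32 "low" then 32 "high"
    low, high = g[:, :BLOCK_N], g[:, BLOCK_N:]
    return ((high << 4) | low).reshape(N, K // 2)

def unpack_rows(packed: torch.Tensor, K: int) -> torch.Tensor:
    # Exactly the inverse: split each byte into nibbles, then undo the transpose.
    N = packed.shape[0]
    rows = packed.reshape(-1, BLOCK_N)
    low, high = rows & 0xF, (rows >> 4) & 0xF
    g = torch.cat([low, high], dim=-1)  # [..., 64]
    g = g.reshape(N // BLOCK_N, K // 2, BLOCK_N, 2).transpose(-3, -2)
    return g.reshape(N, K)

qw = torch.randint(0, 16, (64, 128), dtype=torch.uint8)
assert torch.equal(unpack_rows(pack_rows(qw), 128), qw)

Each row of 32 packed bytes carries the 64 code indices of one (row-block, column-pair) tile, which is why the inverse splits bytes into 32 low nibbles followed by 32 high nibbles before undoing the transpose.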
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 0b89c012a..da168e17b 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -378,11 +378,13 @@ def matmul_4bit(
     if A.device.type == "cpu":
         quant_state.dtype = A.dtype
 
-    if getattr(quant_state, "packing_format_for_cpu", False):
-        out = F.gemv_4bit(A, B, out, state=quant_state)
-        if bias is not None:
-            out += bias
-        return out
+        if getattr(quant_state, "packing_format_for_cpu", False):
+            out = F.gemv_4bit(A, B, out, state=quant_state)
+            if bias is not None:
+                out += bias
+            return out
+        else:
+            return MatMul4Bit.apply(A, B, out, bias, quant_state)
 
     if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
         if A.shape[-1] % quant_state.blocksize != 0:

From 54971118a73746472983152daf3a11f34db43056 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 20 Nov 2025 16:50:13 +0000
Subject: [PATCH 75/78] fix test_gemv

Signed-off-by: jiqing-feng
---
 tests/test_ops.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tests/test_ops.py b/tests/test_ops.py
index da589005e..8d9aa5ab2 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -219,11 +219,26 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         out_features = 1024
         in_features = 256
 
+        if device == "cpu" and blocksize > in_features:
+            pytest.skip("CPU implementation only supports blocksize <= in_features")
+
         A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
         B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
 
         B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype)
         code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize)
+        if device == "cpu" and bitsandbytes.functional.has_avx512bf16():
+            state = bitsandbytes.functional.QuantState(
+                absmax=absmax,
+                shape=B.shape,
+                dtype=A.dtype,
+                blocksize=blocksize,
+                code=code,
+                quant_type=quant_type,
+            )
+            B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state)
+            B_q = B_q.t()
+            absmax = state.absmax
         out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)
 
         assert out.device == A.device

From bdb25c045560d6822c051e9abb5910d97ab3be91 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 25 Nov 2025 13:28:04 +0000
Subject: [PATCH 76/78] disable binsearch

Signed-off-by: jiqing-feng
---
 csrc/cpu_ops.cpp | 105 +++++++++++++++++++++++++++++++----------------
 1 file changed, 70 insertions(+), 35 deletions(-)

diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index f569bf681..9bf784f7a 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -2,6 +2,10 @@
 #include
 #include
 
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
 #ifdef HAS_OPENMP
 #include <omp.h>
 #define BNB_OMP_PARALLEL_FOR _Pragma("omp parallel for")
@@ -9,7 +13,29 @@
 #define BNB_OMP_PARALLEL_FOR
 #endif
 
-using namespace BinSearch;
+namespace {
+
+constexpr int kCodebookSize = 256;
+
+inline unsigned char lookup_code_index(const float* codebook, float value) {
+    value = std::clamp(value, -1.0f, 1.0f);
+    const float* begin = codebook;
+    const float* end = codebook + kCodebookSize;
+    const float* right = std::lower_bound(begin, end, value);
+    if (right == begin) {
+        return 0;
+    }
+    if (right == end) {
+        return static_cast<unsigned char>(kCodebookSize - 1);
+    }
+    const float* left = right - 1;
+    const float dist_left = std::fabs(value - *left);
+    const float dist_right = std::fabs(*right - value);
+    const unsigned char idx = static_cast<unsigned char>(right - begin);
+    return dist_right < dist_left ? idx : idx - 1;
+}
+
+}
 
 #if defined(__AVX512F__)
 #include <immintrin.h>
@@ -181,48 +207,57 @@ void dequantizeBlockwise8bitCpu(
 
 void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) {
-    // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
+    if (blocksize <= 0 || n <= 0)
+        return;
+
+    // Ensure we cover the full expected dynamic range of the codebook.
     code[0] = -1.0f;
 
-    long long num_blocks = n / blocksize;
-    num_blocks += n % blocksize == 0 ? 0 : 1;
+    const auto process_block = [&](long long block_start, long long block_end) {
+        float absmax_block = 0.0f;
+        for (long long i = block_start; i < block_end; ++i) {
+            absmax_block = std::max(absmax_block, std::fabs(A[i]));
+        }
+
+        long long absmax_idx = block_start / blocksize;
+        absmax[absmax_idx] = absmax_block;
+
+        if (absmax_block == 0.0f) {
+            std::fill(out + block_start, out + block_end, 0);
+            return;
+        }
 
-    const uint32 elements_code = 256;
-    BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
+        const float inv_absmax = 1.0f / absmax_block;
+        for (long long i = block_start; i < block_end; ++i) {
+            float normed_value = A[i] * inv_absmax;
+            out[i] = lookup_code_index(code, normed_value);
+        }
+    };
+
+    const long long num_blocks = (n + blocksize - 1) / blocksize;
+    const int thread_wave_size = 256;
 
-    int thread_wave_size = 256;
-    // we chunk the threads into waves of 256 since the max limit is
-    // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
+    // We chunk the threads into waves of 256 since the max limit is between 16k and 64k on Linux
+    // (we reach this when running BLOOM-176B with a large batch size).
     for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) {
-        long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
-        std::vector<std::thread> threads(valid_chunks);
-        std::vector<quantize_block_args> args(valid_chunks);
-
-        int chunks_processed = 0;
-        for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) {
-            long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
-            long long block_end = block_idx + valid_items;
-
-            struct quantize_block_args& arg = args[chunks_processed];
-            arg.bin_searcher = &bin_searcher;
-            arg.code = code;
-            arg.A = A;
-            arg.absmax = absmax;
-            arg.out = out;
-            arg.block_end = block_end;
-            arg.block_idx = block_idx;
-            arg.threadidx = block_idx / blocksize;
-            arg.blocksize = blocksize;
-
-            threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
-            chunks_processed += 1;
-            if (chunks_processed == valid_chunks) {
+        const long long wave_blocks = std::min<long long>(thread_wave_size, num_blocks - offset);
+        std::vector<std::thread> threads;
+        threads.reserve(wave_blocks);
+
+        const long long first_block_start = offset * blocksize;
+        for (long long b = 0; b < wave_blocks; ++b) {
+            const long long block_start = first_block_start + b * blocksize;
+            if (block_start >= n)
                 break;
-            }
+            const long long block_end = std::min(block_start + blocksize, n);
+            threads.emplace_back(process_block, block_start, block_end);
         }
 
-        for (int i = 0; i < valid_chunks; i++)
-            threads[i].join();
+        for (auto& thread : threads) {
+            if (thread.joinable()) {
+                thread.join();
+            }
+        }
     }
 }

From 6cec12dc40ff5c1ea2a928e0ee420f8f2fab93ee Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 25 Nov 2025 13:28:44 +0000
Subject: [PATCH 77/78] fix lint

Signed-off-by: jiqing-feng
---
 csrc/cpu_ops.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index 9bf784f7a..eb7225d7f 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -35,7 +35,7 @@ inline unsigned char lookup_code_index(const float* codebook, float value) {
     return dist_right < dist_left ? idx : idx - 1;
 }
 
-}
+} // namespace
 
 #if defined(__AVX512F__)
 #include <immintrin.h>

From 692a8e152540b10e7ef0ec79e2909420d89d7822 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Wed, 26 Nov 2025 10:06:44 +0000
Subject: [PATCH 78/78] fix save

Signed-off-by: jiqing-feng
---
 bitsandbytes/functional.py | 1 +
 bitsandbytes/nn/modules.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index f97d27cca..2a2a40273 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -2214,6 +2214,7 @@ def _convert_weight_packed_for_cpu_inverse(
     recovered_state.absmax = qabsmax
     recovered_state.offset = offset
     recovered_state.state2 = state2
+    recovered_state.nested = True
 
     recovered_state.dtype = recovered_state.original_dtype
     recovered_state.packing_format_for_cpu = False
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 1c9fac799..d3332acfe 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -515,11 +515,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         save weight and bias,
         then fill state_dict with components of quant_state
         """
-        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
         if getattr(self.weight.quant_state, "packing_format_for_cpu", False):
            self.weight.data, self.weight.quant_state = _convert_weight_packed_for_cpu_inverse(
                 self.weight.data, self.weight.quant_state
             )
+        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
         if getattr(self.weight, "quant_state", None) is not None:
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()
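
---

Reviewer note: a minimal end-to-end sketch of the packed-CPU path this series sets up. The helpers (_convert_weight_packed_for_cpu, its inverse, has_avx512bf16, gemv_4bit) are taken from the diffs above; the driver code around them is illustrative only, and it assumes a double-quantized NF4 state, since the conversion dequantizes a nested absmax (PATCH 72) before repacking.

    # Illustrative driver, not part of the patch series.
    import torch
    import bitsandbytes.functional as F

    B = torch.randn(1024, 256, dtype=torch.bfloat16)
    # compress_statistics=True so the state carries the nested absmax that
    # _convert_weight_packed_for_cpu expects to unpack.
    qB, state = F.quantize_4bit(B, blocksize=64, quant_type="nf4", compress_statistics=True)

    A = torch.randn(1, 1, 256, dtype=torch.bfloat16)

    if F.has_avx512bf16():
        # One-time repack, mirroring the test_functional update above; the
        # extra .t() keeps the gemv_4bit call shape identical to other devices.
        qB, state = F._convert_weight_packed_for_cpu(qB, state)
        qB = qB.t()

    C = F.gemv_4bit(A, qB.t(), state=state)

    # Before serializing, restore the canonical layout -- the ordering that
    # PATCH 78 enforces inside _save_to_state_dict.
    if getattr(state, "packing_format_for_cpu", False):
        qB, state = F._convert_weight_packed_for_cpu_inverse(qB.t(), state)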
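PATCH 76 replaces the BinSearch dependency with a plain std::lower_bound over the sorted 256-entry codebook: clamp the normalized value to [-1, 1], find the first entry >= value, then take whichever of the two neighbouring entries is nearer, with ties going to the lower index. A small torch mirror of that decision rule (a hypothetical reference helper, useful only for cross-checking quantize_cpu in a test):

    import torch

    def lookup_code_index_ref(code: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """Mirror of the C++ lookup: clamp, binary-search, take the nearer neighbour."""
        values = values.clamp(-1.0, 1.0)
        right = torch.searchsorted(code, values)   # first entry >= value
        right = right.clamp(1, code.numel() - 1)   # keep a valid left neighbour
        left = right - 1
        # Strict '<' matches the C++ tie-break: equal distances pick the left entry.
        nearer_right = (code[right] - values) < (values - code[left])
        return torch.where(nearer_right, right, left).to(torch.uint8)

Feeding a block of random values through both paths (the CPU quantize kernel and this reference) is a cheap way to confirm the BinSearch removal is behaviour-preserving.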