
Commit 0f78bad

fix format
Signed-off-by: jiqing-feng <[email protected]>
1 parent 81f1984 commit 0f78bad

6 files changed, +93 -86 lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 1 addition & 0 deletions

@@ -204,6 +204,7 @@ def _(
     return out
 
 if has_avx512bf16():
+
     @register_kernel("bitsandbytes::gemv_4bit", "cpu")
     def _(
         A: torch.Tensor,

bitsandbytes/backends/utils.py

Lines changed: 0 additions & 1 deletion

@@ -1,7 +1,6 @@
 import subprocess
 
 from packaging import version
-from collections.abc import Sequence
 import torch
 
 try:

bitsandbytes/functional.py

Lines changed: 1 addition & 0 deletions

@@ -2243,6 +2243,7 @@ def convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState
     quant_state.dtype = torch.bfloat16
     return final_qweight, quant_state
 
+
 def has_avx512bf16():
     if hasattr(lib, "has_avx512bf16_cpu") and lib.has_avx512bf16_cpu():
         return True

bitsandbytes/utils.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 import subprocess
 
 import torch
-from collections.abc import Sequence
+
 
 def outlier_hook(module, input):
     assert isinstance(module, torch.nn.Linear)

csrc/cpu_ops.h

Lines changed: 76 additions & 74 deletions

@@ -1,33 +1,33 @@
 #ifndef BITSANDBYTES_CPU_OPS_H
 #define BITSANDBYTES_CPU_OPS_H
 
+#include <algorithm>
+#include <cmath>
 #include <cstdint>
 #include <cstring>
 #include <thread>
-#include <cmath>
-#include <algorithm>
 #include <type_traits>
 
 #if defined(_OPENMP)
-#include <omp.h>
+#include <omp.h>
 #endif
 
 // amx-bf16
 #define TILE_M 16
 #define TILE_N 16
 #define TILE_K 32
 // work around compiler internal error
-#define BLOCK_K 128 // 4 * TILE_K
+#define BLOCK_K 128 // 4 * TILE_K
 
 // block size for AMX gemm
 constexpr int block_size_m() { return 2 * TILE_M; }
+
 constexpr int block_size_n() { return 2 * TILE_N; }
 
-template <typename T>
-inline int get_cache_blocks(int chunk_size) {
-  // L2 2MB and ratio of 50%
-  const int L2_size = 2048 * 1024 >> 1;
-  return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
+template <typename T> inline int get_cache_blocks(int chunk_size) {
+    // L2 2MB and ratio of 50%
+    const int L2_size = 2048 * 1024 >> 1;
+    return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
 }
 
 // forced unroll for perf critical path
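The `// L2 2MB and ratio of 50%` comment in get_cache_blocks is terse; the arithmetic below is a worked example of that sizing. The bf16 element size and the chunk of BLOCK_K * block_size_n() elements are illustrative assumptions, not values taken from this commit.

    #include <cstdint>

    int main() {
        // Half of a 2 MiB L2, as in get_cache_blocks().
        constexpr long long l2_half = 2048 * 1024 >> 1; // 1 MiB
        // Assumed chunk: BLOCK_K * block_size_n() = 128 * 32 = 4096 bf16 elements, 2 bytes each.
        constexpr long long chunk_bytes = 128 * 32 * 2; // 8 KiB
        static_assert(l2_half / chunk_bytes == 128, "about 128 such chunks fit in half the L2");
        return 0;
    }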
@@ -37,25 +37,22 @@ inline int get_cache_blocks(int chunk_size) {
 #define ALWAYS_INLINE inline
 #endif
 
-template <int n>
-struct Unroll {
-  template <typename Func, typename... Args>
-  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
-    Unroll<n - 1>{}(f, args...);
-    f(std::integral_constant<int, n - 1>{}, args...);
-  }
+template <int n> struct Unroll {
+    template <typename Func, typename... Args> ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
+        Unroll<n - 1>{}(f, args...);
+        f(std::integral_constant<int, n - 1>{}, args...);
+    }
 };
 
-template <>
-struct Unroll<1> {
-  template <typename Func, typename... Args>
-  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
-    f(std::integral_constant<int, 0>{}, args...);
-  }
+template <> struct Unroll<1> {
+    template <typename Func, typename... Args> ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
+        f(std::integral_constant<int, 0>{}, args...);
+    }
 };
 
-template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
-inline T div_up(T x, T y) { return (x + y - 1) / y; }
+template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0> inline T div_up(T x, T y) {
+    return (x + y - 1) / y;
+}
 
 inline int get_max_threads() {
 #if defined(_OPENMP)
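Unroll<n> above turns a runtime loop into n compile-time invocations of a functor. A minimal usage sketch follows; the include path and the summation example are assumptions for illustration, not code from this commit.

    #include <cstdio>

    #include "cpu_ops.h" // assumed include path for the Unroll<n> helper above

    int main() {
        float acc[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        float sum = 0.0f;
        // The functor is invoked with std::integral_constant<int, k> for k = 0..3,
        // so the index is a compile-time constant inside the body.
        Unroll<4>{}([&](auto k) { sum += acc[k]; });
        std::printf("%.1f\n", sum); // prints 10.0
        return 0;
    }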
@@ -68,60 +65,59 @@ inline int get_max_threads() {
 
 inline int adjust_num_threads(int m) {
     int actual_nth = get_max_threads();
-    if (m == 1) return actual_nth;
+    if (m == 1)
+        return actual_nth;
     return std::max(1, (actual_nth >> 1) * 2);
 }
 
-template <typename func_t>
-inline void parallel_2d(int m, int n, const func_t& f) {
-  // make sure we have even num_threads
-  int nth = adjust_num_threads(m);
-
-  // [NOTE] thread blocking:
-  //
-  //   1) prefer square block per thread
-  //   2) use even number of CPU cores
-  //   3) use all `num_threads` cores
-  //
-  //   we have:
-  //     TM * TN = T
-  //     BM / TM = BN / TN
-  //   then:
-  //     TM = ((BM / BN) * T) ^ 0.5
-  //
-  float r = float(m) / n;
-  int nth_m = std::ceil(std::sqrt(r * nth));
-  int nth_n = 1;
-  for (; nth_m > 0; --nth_m) {
-    nth_n = nth / nth_m;
-    if (nth_m * nth_n == nth) {
-      break;
+template <typename func_t> inline void parallel_2d(int m, int n, const func_t& f) {
+    // make sure we have even num_threads
+    int nth = adjust_num_threads(m);
+
+    // [NOTE] thread blocking:
+    //
+    //   1) prefer square block per thread
+    //   2) use even number of CPU cores
+    //   3) use all `num_threads` cores
+    //
+    //   we have:
+    //     TM * TN = T
+    //     BM / TM = BN / TN
+    //   then:
+    //     TM = ((BM / BN) * T) ^ 0.5
+    //
+    float r = float(m) / n;
+    int nth_m = std::ceil(std::sqrt(r * nth));
+    int nth_n = 1;
+    for (; nth_m > 0; --nth_m) {
+        nth_n = nth / nth_m;
+        if (nth_m * nth_n == nth) {
+            break;
+        }
     }
-  }
 
 #if defined(_OPENMP)
 #pragma omp parallel num_threads(nth)
-  {
-    int ith = omp_get_thread_num();
-    int ith_m = ith / nth_n;
-    int ith_n = ith % nth_n;
+    {
+        int ith = omp_get_thread_num();
+        int ith_m = ith / nth_n;
+        int ith_n = ith % nth_n;
 
-    int thread_block_m = div_up(m, nth_m);
-    int thread_block_n = div_up(n, nth_n);
+        int thread_block_m = div_up(m, nth_m);
+        int thread_block_n = div_up(n, nth_n);
 
-    int begin_m = ith_m * thread_block_m;
-    int end_m = std::min(m, begin_m + thread_block_m);
-    int begin_n = ith_n * thread_block_n;
-    int end_n = std::min(n, begin_n + thread_block_n);
+        int begin_m = ith_m * thread_block_m;
+        int end_m = std::min(m, begin_m + thread_block_m);
+        int begin_n = ith_n * thread_block_n;
+        int end_n = std::min(n, begin_n + thread_block_n);
 
-    f(begin_m, end_m, begin_n, end_n);
-  }
+        f(begin_m, end_m, begin_n, end_n);
+    }
 #else
-  f(0, m, 0, n);
+    f(0, m, 0, n);
 #endif
 }
 
-
 void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n);
 
 typedef enum DataType_t {
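The [NOTE] block in parallel_2d states the blocking formula TM = ((BM / BN) * T) ^ 0.5 without a worked case. The sketch below reruns the same search for the (nth_m, nth_n) thread grid; m = 64, n = 4096 and nth = 16 are illustrative values, not taken from this commit.

    #include <cmath>
    #include <cstdio>

    int main() {
        int m = 64, n = 4096, nth = 16; // illustrative problem size and thread count

        // Same heuristic as parallel_2d: start from ceil(sqrt((m / n) * nth)) and
        // walk down until nth_m * nth_n uses exactly nth threads.
        float r = float(m) / n;
        int nth_m = std::ceil(std::sqrt(r * nth));
        int nth_n = 1;
        for (; nth_m > 0; --nth_m) {
            nth_n = nth / nth_m;
            if (nth_m * nth_n == nth) {
                break;
            }
        }

        // Here r = 1/64, so sqrt(r * nth) = 0.5 and nth_m starts (and stays) at 1:
        // the grid is 1 x 16, i.e. every thread sees all 64 rows and a 256-column slice.
        std::printf("thread grid: %d x %d\n", nth_m, nth_n);
        return 0;
    }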
@@ -155,17 +151,17 @@ static inline fp16_t float_to_fp16(float x) {
     uint32_t bits;
     std::memcpy(&bits, &x, 4);
     uint32_t sign = (bits >> 31) & 0x1;
-    uint32_t exp = (bits >> 23) & 0xFF;
+    uint32_t exp = (bits >> 23) & 0xFF;
     uint32_t mant = bits & 0x7FFFFF;
 
     uint16_t h;
-    if (exp == 0xFF) { // Inf / NaN
+    if (exp == 0xFF) {                      // Inf / NaN
         uint16_t mant16 = mant ? 0x200 : 0; // quiet NaN: set MSB of mantissa
         h = (sign << 15) | (0x1F << 10) | mant16;
-    } else if (exp > 0x70 + 0x1E) { // overflow: exp_f -127 +15 > 30 (exp_f > 142)
+    } else if (exp > 0x70 + 0x1E) {      // overflow: exp_f -127 +15 > 30 (exp_f > 142)
         h = (sign << 15) | (0x1F << 10); // Inf
-    } else if (exp < 0x71) { // subnormal or zero (exp_f < 113)
-        if (exp < 0x67) { // too small -> zero (exp_f < 103)
+    } else if (exp < 0x71) { // subnormal or zero (exp_f < 113)
+        if (exp < 0x67) {    // too small -> zero (exp_f < 103)
             h = (sign << 15);
         } else {
             // subnormal: implicit leading 1
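The trailing comments in float_to_fp16 above give the exponent cutoffs 142, 113 and 103 without showing where they come from. A small sanity check of those constants, illustrative and not part of the commit:

    #include <cassert>

    int main() {
        // fp32 stores a biased exponent e (bias 127); fp16 re-biases it to e - 127 + 15.
        // Normal fp16 exponents occupy [1, 30]; the 10-bit mantissa adds 10 more
        // positions of subnormal range below that.
        const int overflow_limit = 0x70 + 0x1E; // 142
        const int subnormal_limit = 0x71;       // 113
        const int zero_limit = 0x67;            // 103
        assert(overflow_limit - 127 + 15 == 30);    // above 142 the value overflows to Inf
        assert(subnormal_limit - 127 + 15 == 1);    // below 113 the result is subnormal or zero
        assert(subnormal_limit - zero_limit == 10); // 10 mantissa bits of subnormal headroom
        return 0;
    }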
@@ -281,16 +277,22 @@ inline float dDequantizeNF4(unsigned char val) {
         return -1.0f; //*0000
 }
 
-
 template <typename T>
-void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n);
+void dequantizeBlockwise8bitCpu(
+    float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n
+);
 
 template <typename T, int DATA_TYPE>
-void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n);
+void dequantizeBlockwise4bitCpu(
+    unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n
+);
 
 #if defined(__AVX512F__) && defined(__AVX512BF16__)
-template <typename T, int DATA_TYPE>
-void gemv_4bit_inference(int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
+template <typename T, int DATA_TYPE>
+void gemv_4bit_inference(
+    int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w,
+    const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
+);
 #endif
 
 #if defined(__AVX512F__)

csrc/pythonInterface.cpp

Lines changed: 14 additions & 10 deletions

@@ -847,11 +847,13 @@ void cdequantize_blockwise_cpu_fp32(
 ) {
     dequantizeBlockwise8bitCpu<float>(code, A, absmax, out, blocksize, n);
 }
+
 void cdequantize_blockwise_cpu_bf16(
     float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n
 ) {
     dequantizeBlockwise8bitCpu<bf16_t>(code, A, absmax, out, blocksize, n);
 }
+
 void cdequantize_blockwise_cpu_fp16(
     float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n
 ) {
@@ -863,11 +865,13 @@ void cdequantize_blockwise_cpu_fp4_fp32(
 ) {
     dequantizeBlockwise4bitCpu<float, FP4>(A, absmax, out, blocksize, m, n);
 }
+
 void cdequantize_blockwise_cpu_fp4_bf16(
     unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n
 ) {
     dequantizeBlockwise4bitCpu<bf16_t, FP4>(A, absmax, out, blocksize, m, n);
 }
+
 void cdequantize_blockwise_cpu_fp4_fp16(
     unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n
 ) {
@@ -879,11 +883,13 @@ void cdequantize_blockwise_cpu_nf4_fp32(
 ) {
     dequantizeBlockwise4bitCpu<float, NF4>(A, absmax, out, blocksize, m, n);
 }
+
 void cdequantize_blockwise_cpu_nf4_bf16(
     unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n
 ) {
     dequantizeBlockwise4bitCpu<bf16_t, NF4>(A, absmax, out, blocksize, m, n);
 }
+
 void cdequantize_blockwise_cpu_nf4_fp16(
     unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n
 ) {
@@ -892,24 +898,22 @@ void cdequantize_blockwise_cpu_nf4_fp16(
 
 #if defined(__AVX512F__) && defined(__AVX512BF16__)
 void gemv_4bit_inference_cpu_fp4_bf16(
-    int64_t M, int64_t N, int64_t K,
-    const bf16_t* __restrict__ x, const unsigned char* __restrict__ w,
-    const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out,
-    int64_t blocksize, int64_t x_stride, int64_t out_stride
+    int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w,
+    const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
 ) {
     gemv_4bit_inference<bf16_t, FP4>(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride);
 }
+
 void gemv_4bit_inference_cpu_nf4_bf16(
-    int64_t M, int64_t N, int64_t K,
-    const bf16_t* __restrict__ x, const unsigned char* __restrict__ w,
-    const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out,
-    int64_t blocksize, int64_t x_stride, int64_t out_stride
+    int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w,
+    const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
 ) {
     gemv_4bit_inference<bf16_t, NF4>(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride);
 }
 #endif
 #if defined(__AVX512F__)
-bool has_avx512f_cpu() return has_avx512f()
-bool has_avx512bf16_cpu() return has_avx512bf16()
+bool has_avx512f_cpu() { return has_avx512f(); }
+
+bool has_avx512bf16_cpu() { return has_avx512bf16(); }
 #endif
 }
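has_avx512f_cpu and has_avx512bf16_cpu are the exported C symbols that bitsandbytes/functional.py probes via hasattr(lib, "has_avx512bf16_cpu"); they forward to the has_avx512f()/has_avx512bf16() helpers declared in csrc/cpu_ops.h. As a point of reference only (this is not necessarily how cpu_ops.h implements them), a common way to write such runtime probes with GCC/Clang is the __builtin_cpu_supports builtin, assuming a toolchain that recognizes these feature names:

    #include <cstdio>

    // Illustrative runtime CPU-feature probe; the helper names are hypothetical.
    static bool example_has_avx512f() { return __builtin_cpu_supports("avx512f"); }
    static bool example_has_avx512bf16() { return __builtin_cpu_supports("avx512bf16"); }

    int main() {
        std::printf("avx512f: %d, avx512bf16: %d\n", (int)example_has_avx512f(), (int)example_has_avx512bf16());
        return 0;
    }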
