#ifndef BITSANDBYTES_CPU_OPS_H
#define BITSANDBYTES_CPU_OPS_H

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <thread>
#include <type_traits>

#if defined(_OPENMP)
#include <omp.h>
#endif

// amx-bf16
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
// work around compiler internal error
#define BLOCK_K 128 // 4 * TILE_K

// block size for AMX gemm
constexpr int block_size_m() { return 2 * TILE_M; }

constexpr int block_size_n() { return 2 * TILE_N; }

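// i.e. block_size_m() == block_size_n() == 32: each macro block is a 2x2 grid
// of 16x16 AMX tiles, and BLOCK_K = 4 * TILE_K = 128 groups four K-tiles
// (split out as a constant to dodge the compiler issue noted above).
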
template <typename T> inline int get_cache_blocks(int chunk_size) {
    // L2 2MB and ratio of 50%
    const int L2_size = 2048 * 1024 >> 1;
    return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
}

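// e.g. get_cache_blocks<uint16_t>(2048) = 1 MB / (2048 * 2 B) = 256, i.e. 256
// chunks of 2048 bf16 values fit in the reserved half of a 2 MB L2.
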
// forced unroll for perf critical path
#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif

template <int n> struct Unroll {
    template <typename Func, typename... Args> ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
        Unroll<n - 1>{}(f, args...);
        f(std::integral_constant<int, n - 1>{}, args...);
    }
};

template <> struct Unroll<1> {
    template <typename Func, typename... Args> ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
        f(std::integral_constant<int, 0>{}, args...);
    }
};

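// Usage sketch: Unroll<4>{} invokes the functor once per compile-time index,
// passing std::integral_constant<int, 0> ... <int, 3> as the first argument:
//
//   float acc[4] = {};
//   Unroll<4>{}([&](auto i) { acc[i] += 1.0f; });
//
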
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0> inline T div_up(T x, T y) {
    return (x + y - 1) / y;
}

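// e.g. div_up(10, 4) == 3: ceiling division for the integral block counts used below.
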
inline int get_max_threads() {
#if defined(_OPENMP)
    return omp_get_max_threads();
#else
    // fallback when OpenMP is unavailable
    return (int)std::thread::hardware_concurrency();
#endif
}

inline int adjust_num_threads(int m) {
    int actual_nth = get_max_threads();
    if (m == 1)
        return actual_nth;
    return std::max(1, (actual_nth >> 1) * 2);
}

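// e.g. 7 available threads and m > 1 yields (7 >> 1) * 2 = 6, keeping the
// thread count even so the 2D grid in parallel_2d below stays balanced.
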
template <typename func_t> inline void parallel_2d(int m, int n, const func_t& f) {
    // make sure we have even num_threads
    int nth = adjust_num_threads(m);

    // [NOTE] thread blocking:
    //
    //   1) prefer square block per thread
    //   2) use even number of CPU cores
    //   3) use all `num_threads` cores
    //
    // we have:
    //   TM * TN = T
    //   BM / TM = BN / TN
    // then:
    //   TM = ((BM / BN) * T) ^ 0.5
    //
    float r = float(m) / n;
    int nth_m = std::ceil(std::sqrt(r * nth));
    int nth_n = 1;
    for (; nth_m > 0; --nth_m) {
        nth_n = nth / nth_m;
        if (nth_m * nth_n == nth) {
            break;
        }
    }

#if defined(_OPENMP)
#pragma omp parallel num_threads(nth)
    {
        int ith = omp_get_thread_num();
        int ith_m = ith / nth_n;
        int ith_n = ith % nth_n;

        int thread_block_m = div_up(m, nth_m);
        int thread_block_n = div_up(n, nth_n);

        int begin_m = ith_m * thread_block_m;
        int end_m = std::min(m, begin_m + thread_block_m);
        int begin_n = ith_n * thread_block_n;
        int end_n = std::min(n, begin_n + thread_block_n);

        f(begin_m, end_m, begin_n, end_n);
    }
#else
    f(0, m, 0, n);
#endif
}
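
// Worked example of the blocking above: m = 512, n = 2048, nth = 8 gives
// r = 0.25, nth_m = ceil(sqrt(0.25 * 8)) = 2, nth_n = 4, so each of the
// 8 threads handles a 256 x 512 block (div_up(512, 2) x div_up(2048, 4)).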
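// Blockwise 8-bit quantization: for every `blocksize` elements of A, absmax
// stores the block's max absolute value and each element of out is the index
// of the closest entry in the 256-entry `code` table (a sketch of the
// contract; the implementation defines the exact rounding).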
void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n);

typedef enum DataType_t {
    General8bit = 0,
    FP4 = 1,
    NF4 = 2,
} DataType_t;

// half-precision value carried as its raw 16-bit bit pattern
typedef uint16_t fp16_t;

// ...

static inline fp16_t float_to_fp16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, 4);
    uint32_t sign = (bits >> 31) & 0x1;
    uint32_t exp = (bits >> 23) & 0xFF;
    uint32_t mant = bits & 0x7FFFFF;

    uint16_t h;
    if (exp == 0xFF) {                      // Inf / NaN
        uint16_t mant16 = mant ? 0x200 : 0; // quiet NaN: set MSB of mantissa
        h = (sign << 15) | (0x1F << 10) | mant16;
    } else if (exp > 0x70 + 0x1E) {      // overflow: exp_f -127 +15 > 30 (exp_f > 142)
        h = (sign << 15) | (0x1F << 10); // Inf
    } else if (exp < 0x71) {             // subnormal or zero (exp_f < 113)
        if (exp < 0x67) {                // too small -> zero (exp_f < 103)
            h = (sign << 15);
        } else {
            // subnormal: implicit leading 1
            // ... (rest of float_to_fp16 elided)

inline float dDequantizeNF4(unsigned char val) {
    // ...
    return -1.0f; // *0000
}
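
// NF4 decodes a 4-bit code into one of 16 fixed values in [-1, 1] (quantiles
// tuned for normally distributed weights); code 0b0000 maps to -1.0f as above.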

template <typename T>
void dequantizeBlockwise8bitCpu(
    float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n
);

template <typename T, int DATA_TYPE>
void dequantizeBlockwise4bitCpu(
    unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n
);

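// Usage sketch (assumes the templates are instantiated for float and NF4 in
// the defining translation unit; buffer names are illustrative):
//
//   std::vector<float> out(m * n);
//   dequantizeBlockwise4bitCpu<float, NF4>(packed, absmax, out.data(), blocksize, m, n);
//
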
#if defined(__AVX512F__) && defined(__AVX512BF16__)
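// Reading of the signature (the kernel lives in the AVX512 implementation):
// w holds 4-bit quantized weights that are dequantized on the fly with the
// per-block absmax scales and accumulated against x; M/N/K are the GEMM-style
// dimensions and x_stride/out_stride the row strides of x and out.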
template <typename T, int DATA_TYPE>
void gemv_4bit_inference(
    int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w,
    const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
);
#endif

#if defined(__AVX512F__)