|
#include <cpu_ops.h>

#include <algorithm>
#include <cmath>
#include <system_error>
#include <thread>
#include <vector>
| 8 | + |
5 | 9 | #ifdef HAS_OPENMP |
6 | 10 | #include <omp.h> |
7 | 11 | #define BNB_OMP_PARALLEL_FOR _Pragma("omp parallel for") |
8 | 12 | #else |
9 | 13 | #define BNB_OMP_PARALLEL_FOR |
10 | 14 | #endif |
11 | 15 |
|
12 | | -using namespace BinSearch; |
namespace {

// Number of entries in an 8-bit quantization codebook.
constexpr int kCodebookSize = 256;

// Returns the index of the codebook entry nearest to `value`.
//
// codebook: pointer to kCodebookSize floats. They must be sorted in ascending
//           order, because the lookup is a binary search (std::lower_bound).
// value:    the value to encode; it is clamped to [-1, 1] before the search.
//
// When `value` is exactly halfway between two neighbouring entries, the lower
// index wins. A NaN `value` maps to index 0.
inline unsigned char lookup_code_index(const float* codebook, float value) {
    // NaN does not satisfy the ordering requirements of std::clamp, so it is
    // handled up front. Index 0 is what the unguarded search produced in
    // practice (no entry compares less than NaN), so the result is unchanged.
    if (std::isnan(value)) {
        return 0;
    }

    value = std::clamp(value, -1.0f, 1.0f);

    const float* const begin = codebook;
    const float* const end = codebook + kCodebookSize;

    // First entry that is not less than `value`.
    const float* const right = std::lower_bound(begin, end, value);
    if (right == begin) {
        return 0;
    }
    if (right == end) {
        return static_cast<unsigned char>(kCodebookSize - 1);
    }

    // `value` lies between *left and *right: pick whichever is closer.
    const float* const left = right - 1;
    const float dist_left = std::fabs(value - *left);
    const float dist_right = std::fabs(*right - value);
    const float* const nearest = dist_right < dist_left ? right : left;
    return static_cast<unsigned char>(nearest - begin);
}

} // namespace
13 | 39 |
|
14 | 40 | #if defined(__AVX512F__) |
15 | 41 | #include <immintrin.h> |
@@ -181,48 +207,57 @@ void dequantizeBlockwise8bitCpu( |
181 | 207 |
|
182 | 208 | void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { |
183 | 209 |
|
184 | | - // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below |
| 210 | + if (blocksize <= 0 || n <= 0) |
| 211 | + return; |
| 212 | + |
| 213 | + // Ensure we cover the full expected dynamic range of the codebook. |
185 | 214 | code[0] = -1.0f; |
186 | 215 |
|
187 | | - long long num_blocks = n / blocksize; |
188 | | - num_blocks += n % blocksize == 0 ? 0 : 1; |
| 216 | + const auto process_block = [&](long long block_start, long long block_end) { |
| 217 | + float absmax_block = 0.0f; |
| 218 | + for (long long i = block_start; i < block_end; ++i) { |
| 219 | + absmax_block = std::max(absmax_block, std::fabs(A[i])); |
| 220 | + } |
| 221 | + |
| 222 | + long long absmax_idx = block_start / blocksize; |
| 223 | + absmax[absmax_idx] = absmax_block; |
| 224 | + |
| 225 | + if (absmax_block == 0.0f) { |
| 226 | + std::fill(out + block_start, out + block_end, 0); |
| 227 | + return; |
| 228 | + } |
189 | 229 |
|
190 | | - const uint32 elements_code = 256; |
191 | | - BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code); |
| 230 | + const float inv_absmax = 1.0f / absmax_block; |
| 231 | + for (long long i = block_start; i < block_end; ++i) { |
| 232 | + float normed_value = A[i] * inv_absmax; |
| 233 | + out[i] = lookup_code_index(code, normed_value); |
| 234 | + } |
| 235 | + }; |
| 236 | + |
| 237 | + const long long num_blocks = (n + blocksize - 1) / blocksize; |
| 238 | + const int thread_wave_size = 256; |
192 | 239 |
|
193 | | - int thread_wave_size = 256; |
194 | | - // we chunk the threads into waves of 256 since the max limit is |
195 | | - // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) |
| 240 | + // We chunk the threads into waves of 256 since the max limit is between 16k and 64k on Linux |
| 241 | + // (we reach this when running BLOOM-176B with a large batch size). |
196 | 242 | for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) { |
197 | | - long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; |
198 | | - std::vector<std::thread> threads(valid_chunks); |
199 | | - std::vector<quantize_block_args> args(valid_chunks); |
200 | | - |
201 | | - int chunks_processed = 0; |
202 | | - for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) { |
203 | | - long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; |
204 | | - long long block_end = block_idx + valid_items; |
205 | | - |
206 | | - struct quantize_block_args& arg = args[chunks_processed]; |
207 | | - arg.bin_searcher = &bin_searcher; |
208 | | - arg.code = code; |
209 | | - arg.A = A; |
210 | | - arg.absmax = absmax; |
211 | | - arg.out = out; |
212 | | - arg.block_end = block_end; |
213 | | - arg.block_idx = block_idx; |
214 | | - arg.threadidx = block_idx / blocksize; |
215 | | - arg.blocksize = blocksize; |
216 | | - |
217 | | - threads[chunks_processed] = std::thread([arg] { quantize_block(arg); }); |
218 | | - chunks_processed += 1; |
219 | | - if (chunks_processed == valid_chunks) { |
| 243 | + const long long wave_blocks = std::min<long long>(thread_wave_size, num_blocks - offset); |
| 244 | + std::vector<std::thread> threads; |
| 245 | + threads.reserve(wave_blocks); |
| 246 | + |
| 247 | + const long long first_block_start = offset * blocksize; |
| 248 | + for (long long b = 0; b < wave_blocks; ++b) { |
| 249 | + const long long block_start = first_block_start + b * blocksize; |
| 250 | + if (block_start >= n) |
220 | 251 | break; |
221 | | - } |
| 252 | + const long long block_end = std::min(block_start + blocksize, n); |
| 253 | + threads.emplace_back(process_block, block_start, block_end); |
222 | 254 | } |
223 | 255 |
|
224 | | - for (int i = 0; i < valid_chunks; i++) |
225 | | - threads[i].join(); |
| 256 | + for (auto& thread : threads) { |
| 257 | + if (thread.joinable()) { |
| 258 | + thread.join(); |
| 259 | + } |
| 260 | + } |
226 | 261 | } |
227 | 262 | } |
228 | 263 |
|
|
0 commit comments