Skip to content

Commit bdb25c0

Browse files
committed
disable binsearch
Signed-off-by: jiqing-feng <[email protected]>
1 parent 0045c4b commit bdb25c0

File tree

1 file changed

+70
-35
lines changed

1 file changed

+70
-35
lines changed

csrc/cpu_ops.cpp

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,40 @@
22
#include <cpu_ops.h>
33
#include <thread>
44

5+
#include <algorithm>
6+
#include <cmath>
7+
#include <vector>
8+
59
#ifdef HAS_OPENMP
610
#include <omp.h>
711
#define BNB_OMP_PARALLEL_FOR _Pragma("omp parallel for")
812
#else
913
#define BNB_OMP_PARALLEL_FOR
1014
#endif
1115

12-
using namespace BinSearch;
16+
namespace {
17+
18+
constexpr int kCodebookSize = 256;
19+
20+
inline unsigned char lookup_code_index(const float* codebook, float value) {
21+
value = std::clamp(value, -1.0f, 1.0f);
22+
const float* begin = codebook;
23+
const float* end = codebook + kCodebookSize;
24+
const float* right = std::lower_bound(begin, end, value);
25+
if (right == begin) {
26+
return 0;
27+
}
28+
if (right == end) {
29+
return static_cast<unsigned char>(kCodebookSize - 1);
30+
}
31+
const float* left = right - 1;
32+
const float dist_left = std::fabs(value - *left);
33+
const float dist_right = std::fabs(*right - value);
34+
const unsigned char idx = static_cast<unsigned char>(right - begin);
35+
return dist_right < dist_left ? idx : idx - 1;
36+
}
37+
38+
}
1339

1440
#if defined(__AVX512F__)
1541
#include <immintrin.h>
@@ -181,48 +207,57 @@ void dequantizeBlockwise8bitCpu(
181207

182208
void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) {
183209

184-
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
210+
if (blocksize <= 0 || n <= 0)
211+
return;
212+
213+
// Ensure we cover the full expected dynamic range of the codebook.
185214
code[0] = -1.0f;
186215

187-
long long num_blocks = n / blocksize;
188-
num_blocks += n % blocksize == 0 ? 0 : 1;
216+
const auto process_block = [&](long long block_start, long long block_end) {
217+
float absmax_block = 0.0f;
218+
for (long long i = block_start; i < block_end; ++i) {
219+
absmax_block = std::max(absmax_block, std::fabs(A[i]));
220+
}
221+
222+
long long absmax_idx = block_start / blocksize;
223+
absmax[absmax_idx] = absmax_block;
224+
225+
if (absmax_block == 0.0f) {
226+
std::fill(out + block_start, out + block_end, 0);
227+
return;
228+
}
189229

190-
const uint32 elements_code = 256;
191-
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
230+
const float inv_absmax = 1.0f / absmax_block;
231+
for (long long i = block_start; i < block_end; ++i) {
232+
float normed_value = A[i] * inv_absmax;
233+
out[i] = lookup_code_index(code, normed_value);
234+
}
235+
};
236+
237+
const long long num_blocks = (n + blocksize - 1) / blocksize;
238+
const int thread_wave_size = 256;
192239

193-
int thread_wave_size = 256;
194-
// we chunk the threads into waves of 256 since the max limit is
195-
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
240+
// We chunk the threads into waves of 256 since the max limit is between 16k and 64k on Linux
241+
// (we reach this when running BLOOM-176B with a large batch size).
196242
for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) {
197-
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
198-
std::vector<std::thread> threads(valid_chunks);
199-
std::vector<quantize_block_args> args(valid_chunks);
200-
201-
int chunks_processed = 0;
202-
for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) {
203-
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
204-
long long block_end = block_idx + valid_items;
205-
206-
struct quantize_block_args& arg = args[chunks_processed];
207-
arg.bin_searcher = &bin_searcher;
208-
arg.code = code;
209-
arg.A = A;
210-
arg.absmax = absmax;
211-
arg.out = out;
212-
arg.block_end = block_end;
213-
arg.block_idx = block_idx;
214-
arg.threadidx = block_idx / blocksize;
215-
arg.blocksize = blocksize;
216-
217-
threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
218-
chunks_processed += 1;
219-
if (chunks_processed == valid_chunks) {
243+
const long long wave_blocks = std::min<long long>(thread_wave_size, num_blocks - offset);
244+
std::vector<std::thread> threads;
245+
threads.reserve(wave_blocks);
246+
247+
const long long first_block_start = offset * blocksize;
248+
for (long long b = 0; b < wave_blocks; ++b) {
249+
const long long block_start = first_block_start + b * blocksize;
250+
if (block_start >= n)
220251
break;
221-
}
252+
const long long block_end = std::min(block_start + blocksize, n);
253+
threads.emplace_back(process_block, block_start, block_end);
222254
}
223255

224-
for (int i = 0; i < valid_chunks; i++)
225-
threads[i].join();
256+
for (auto& thread : threads) {
257+
if (thread.joinable()) {
258+
thread.join();
259+
}
260+
}
226261
}
227262
}
228263

0 commit comments

Comments
 (0)