44
55using namespace BinSearch ;
66
7- void dequantize_cpu (float * code, unsigned char * A, float * absmax, float * out, long long blocksize, long long n) {
7+ void dequantize_cpu (float * code, unsigned char * A, float * absmax, float * out, long long blocksize, long long n) {
88 for (long long block_idx = 0 ; block_idx < n; block_idx += blocksize) {
99 long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
1010 long long block_end = block_idx + valid_items;
@@ -13,8 +13,7 @@ void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, lo
1313 }
1414}
1515
16- void quantize_cpu (float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
17- {
16+ void quantize_cpu (float * code, float * A, float * absmax, unsigned char * out, long long blocksize, long long n) {
1817
1918 // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
2019 code[0 ] = -1 .0f ;
@@ -28,36 +27,35 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
2827 int thread_wave_size = 256 ;
2928 // we chunk the threads into waves of 256 since the max limit is
3029 // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
31- for (long long offset = 0 ; offset < num_blocks; offset+= thread_wave_size)
32- {
33- long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset ;
34- std::vector<std::thread> threads (valid_chunks);
35- std::vector<quantize_block_args> args (valid_chunks);
36-
37- int chunks_processed = 0 ;
38- for ( long long block_idx = offset* blocksize; block_idx < n; block_idx += blocksize)
39- {
40- long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
41- long long block_end = block_idx + valid_items ;
42-
43- struct quantize_block_args & arg = args[chunks_processed] ;
44- arg.bin_searcher = &bin_searcher ;
45- arg.code = code ;
46- arg.A = A ;
47- arg.absmax = absmax ;
48- arg.out = out ;
49- arg.block_end = block_end ;
50- arg.block_idx = block_idx ;
51- arg. threadidx = block_idx / blocksize;
52- arg. blocksize = blocksize ;
53-
54- threads[chunks_processed] = std::thread ([arg] { quantize_block (arg); });
55- chunks_processed += 1 ;
56- if (chunks_processed == valid_chunks){ break ; }
57- }
58-
59- for (int i = 0 ; i < valid_chunks; i++)
60- threads[i].join ();
30+ for (long long offset = 0 ; offset < num_blocks; offset += thread_wave_size) {
31+ long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
32+ std::vector<std::thread> threads (valid_chunks) ;
33+ std::vector<quantize_block_args> args (valid_chunks);
34+
35+ int chunks_processed = 0 ;
36+ for ( long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) {
37+ long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
38+ long long block_end = block_idx + valid_items;
39+
40+ struct quantize_block_args & arg = args[chunks_processed] ;
41+ arg. bin_searcher = &bin_searcher;
42+ arg. code = code ;
43+ arg.A = A ;
44+ arg.absmax = absmax ;
45+ arg.out = out ;
46+ arg.block_end = block_end ;
47+ arg.block_idx = block_idx ;
48+ arg.threadidx = block_idx / blocksize ;
49+ arg.blocksize = blocksize ;
50+
51+ threads[chunks_processed] = std::thread ([arg] { quantize_block (arg); }) ;
52+ chunks_processed += 1 ;
53+ if (chunks_processed == valid_chunks) {
54+ break ;
55+ }
56+ }
57+
58+ for (int i = 0 ; i < valid_chunks; i++)
59+ threads[i].join ();
6160 }
62-
6361}
0 commit comments