Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions csrc/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ void quantize_block(const quantize_block_args& args) {
if (idx < 255) {
float dist_left = fabs(normed_value - (args.code[idx]));
float dist_right = fabs(normed_value - (args.code[idx + 1]));
if (dist_right < dist_left) { idx += 1; }
if (dist_right < dist_left) {
idx += 1;
}
}

// 5. store index
args.out[i] = (unsigned char) idx;
args.out[i] = (unsigned char)idx;
}
}
55 changes: 28 additions & 27 deletions csrc/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,48 @@

// TODO: Let's make some of these constexpr and put in a namespace.

#define BNB_CC_MAXWELL 500
#define BNB_CC_MAXWELL2 520
#define BNB_CC_MAXWELL2_X1 530
#define BNB_CC_PASCAL 600
#define BNB_CC_PASCAL_X2 620
#define BNB_CC_VOLTA 700
#define BNB_CC_VOLTA_XAVIER 720
#define BNB_CC_TURING 750
#define BNB_CC_AMPERE 800
#define BNB_CC_AMPERE2 860
#define BNB_CC_AMPERE2_ORIN 870
#define BNB_CC_ADA 890
#define BNB_CC_HOPPER 900
#define BNB_CC_BLACKWELL 1000
#define BNB_CC_MAXWELL 500
#define BNB_CC_MAXWELL2 520
#define BNB_CC_MAXWELL2_X1 530
#define BNB_CC_PASCAL 600
#define BNB_CC_PASCAL_X2 620
#define BNB_CC_VOLTA 700
#define BNB_CC_VOLTA_XAVIER 720
#define BNB_CC_TURING 750
#define BNB_CC_AMPERE 800
#define BNB_CC_AMPERE2 860
#define BNB_CC_AMPERE2_ORIN 870
#define BNB_CC_ADA 890
#define BNB_CC_HOPPER 900
#define BNB_CC_BLACKWELL 1000

#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)

#define BNB_WARP_SIZE 32
#define BNB_WARP_SIZE 32

// The maximum number of resident threads per SM varies by arch.
// For A100/H100 and all prior to Turing, it is 2048, which allows
// for 2 full blocks of 1024 threads per SM.
// Reference: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
// Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
#if __CUDA_ARCH__ == 750
#define BNB_MAX_THREADS_PER_SM 1024
#define BNB_MAX_THREADS_PER_SM 1024
#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
#define BNB_MAX_THREADS_PER_SM 1536
#define BNB_MAX_THREADS_PER_SM 1536
#else
#define BNB_MAX_THREADS_PER_SM 2048
#define BNB_MAX_THREADS_PER_SM 2048
#endif

// Maximum resident warps per SM is always directly related to the number of threads.
#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))
#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))

// Maximum resident blocks per SM may vary.
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
#define BNB_MAX_BLOCKS_PER_SM 16
#define BNB_MAX_BLOCKS_PER_SM 16
#else
#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)
#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)
#endif
15 changes: 6 additions & 9 deletions csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,18 @@

using namespace BinSearch;

#define BLOCK_SIZE 16384

struct quantize_block_args {
BinAlgo<Scalar, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
unsigned char *out;
BinAlgo<Scalar, float, Direct2>* bin_searcher;
float* code;
float* A;
float* absmax;
unsigned char* out;
long long block_end;
long long block_idx;
long long threadidx;
long long blocksize;
long long blocksize;
};


void quantize_block(const quantize_block_args& args);

#endif
66 changes: 32 additions & 34 deletions csrc/cpu_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

using namespace BinSearch;

void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) {
void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n) {
for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
long long block_end = block_idx + valid_items;
Expand All @@ -13,8 +13,7 @@ void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, lo
}
}

void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
{
void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) {

// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
code[0] = -1.0f;
Expand All @@ -28,36 +27,35 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
int thread_wave_size = 256;
// we chunk the threads into waves of 256 since the max limit is
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
{
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
std::vector<std::thread> threads(valid_chunks);
std::vector<quantize_block_args> args(valid_chunks);

int chunks_processed = 0;
for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
{
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
long long block_end = block_idx + valid_items;

struct quantize_block_args& arg = args[chunks_processed];
arg.bin_searcher = &bin_searcher;
arg.code = code;
arg.A = A;
arg.absmax = absmax;
arg.out = out;
arg.block_end = block_end;
arg.block_idx = block_idx;
arg.threadidx = block_idx / blocksize;
arg.blocksize = blocksize;

threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
chunks_processed += 1;
if(chunks_processed == valid_chunks){ break; }
}

for (int i = 0; i < valid_chunks; i++)
threads[i].join();
for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) {
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
std::vector<std::thread> threads(valid_chunks);
std::vector<quantize_block_args> args(valid_chunks);

int chunks_processed = 0;
for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) {
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
long long block_end = block_idx + valid_items;

struct quantize_block_args& arg = args[chunks_processed];
arg.bin_searcher = &bin_searcher;
arg.code = code;
arg.A = A;
arg.absmax = absmax;
arg.out = out;
arg.block_end = block_end;
arg.block_idx = block_idx;
arg.threadidx = block_idx / blocksize;
arg.blocksize = blocksize;

threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
chunks_processed += 1;
if (chunks_processed == valid_chunks) {
break;
}
}

for (int i = 0; i < valid_chunks; i++)
threads[i].join();
}

}
4 changes: 2 additions & 2 deletions csrc/cpu_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <iostream>
#include <stdio.h>

void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n);
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n);
void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n);
void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n);

#endif
Loading