Skip to content

Commit c059bd2

Browse files
committed
Added additional blocksizes: {64, 128, 256}.
1 parent eb028e6 commit c059bd2

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

bitsandbytes/functional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
503503
out = torch.zeros_like(A, dtype=torch.uint8)
504504

505505
if A.device.type != 'cpu':
506-
assert blocksize in [4096, 2048, 1024, 512]
506+
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
507507
cblocksize = ct.c_int32(blocksize)
508508
prev_device = pre_call(A.device)
509509
code = code.to(A.device)
@@ -586,8 +586,8 @@ def dequantize_blockwise(
586586
if A.device.type != 'cpu':
587587
device = pre_call(A.device)
588588
code = code.to(A.device)
589-
if blocksize not in [2048, 4096, 1024, 512]:
590-
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
589+
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
590+
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
591591
is_on_gpu([A, out])
592592
if out.dtype == torch.float32:
593593
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))

csrc/kernels.cu

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
454454
__shared__ float smem_code[256];
455455
__shared__ float smem_absmax_value[1];
456456

457-
if(threadIdx.x < 256)
458-
smem_code[threadIdx.x] = code[threadIdx.x];
457+
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
458+
smem_code[i] = code[i];
459459

460460
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
461461
{
@@ -2799,6 +2799,12 @@ template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half
27992799
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
28002800
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
28012801
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2802+
template __global__ void kQuantizeBlockwise<half, 256, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2803+
template __global__ void kQuantizeBlockwise<float, 256, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2804+
template __global__ void kQuantizeBlockwise<half, 128, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2805+
template __global__ void kQuantizeBlockwise<float, 128, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2806+
template __global__ void kQuantizeBlockwise<half, 64, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2807+
template __global__ void kQuantizeBlockwise<float, 64, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
28022808

28032809
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
28042810
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
@@ -2808,6 +2814,12 @@ template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, u
28082814
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
28092815
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
28102816
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
2817+
template __global__ void kDequantizeBlockwise<half, 256, 128, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
2818+
template __global__ void kDequantizeBlockwise<float, 256, 128, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
2819+
template __global__ void kDequantizeBlockwise<half, 128, 64, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
2820+
template __global__ void kDequantizeBlockwise<float, 128, 64, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
2821+
template __global__ void kDequantizeBlockwise<half, 64, 64, 1>(float *code, unsigned char * A, float * absmax, half *out, const int n);
2822+
template __global__ void kDequantizeBlockwise<float, 64, 64, 1>(float *code, unsigned char * A, float * absmax, float *out, const int n);
28112823

28122824

28132825

csrc/ops.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
6565
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
6666
else if(blocksize == 512)
6767
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
68+
else if(blocksize == 256)
69+
kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
70+
else if(blocksize == 128)
71+
kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
72+
else if(blocksize == 64)
73+
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
6874

6975

7076
CUDA_CHECK_RETURN(cudaPeekAtLastError());
@@ -82,6 +88,12 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
8288
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
8389
else if(blocksize == 512)
8490
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
91+
else if(blocksize == 256)
92+
kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
93+
else if(blocksize == 128)
94+
kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
95+
else if(blocksize == 64)
96+
kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
8597

8698
CUDA_CHECK_RETURN(cudaPeekAtLastError());
8799
}

0 commit comments

Comments
 (0)