
Commit 08fa2e7

Fixed bug in cpu quant; faster GPU dequant.

1 parent 62a333a · commit 08fa2e7

File tree: 5 files changed, +44 -25 lines

  bitsandbytes/cextension.py
  bitsandbytes/functional.py
  csrc/kernels.cu
  csrc/kernels.cuh
  tests/test_functional.py

bitsandbytes/cextension.py
Lines changed: 0 additions & 1 deletion

@@ -94,7 +94,6 @@ def run_cuda_setup(self):
             else:
                 self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
                 self.lib = ct.cdll.LoadLibrary(binary_path)
-                print(self.lib)
         except Exception as ex:
             self.add_log_entry(str(ex))
             self.print_log_stack()

bitsandbytes/functional.py
Lines changed: 12 additions & 10 deletions

@@ -458,16 +458,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
     """


-    prev_device = pre_call(A.device)
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
-        code = code.to(A.device)

     if absmax is None:
         n = A.numel()
-        blocksize = (blocksize if A.device.type == 'cuda' else 4096)
         blocks = n // blocksize
         blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)
@@ -477,8 +474,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra

     if A.device.type != 'cpu':
         assert blocksize in [4096, 2048, 1024, 512]
-        is_on_gpu([code, A, absmax, out, rand])
         cblocksize = ct.c_int32(blocksize)
+        prev_device = pre_call(A.device)
+        code = code.to(A.device)
         if rand is not None:
             is_on_gpu([code, A, out, absmax, rand])
             assert blocksize==4096
@@ -498,11 +496,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
                 lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
             else:
                 raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+        post_call(A.device)
     else:
         # cpu
+        code = code.cpu()
         assert rand is None
         lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
-        post_call(A.device)

     return out, (absmax, code)

@@ -541,32 +540,35 @@ def dequantize_blockwise(
         Dequantized tensor (default: float32)
     """
     assert quant_state is not None or absmax is not None
-    device = pre_call(A.device)
     if code is None and quant_state is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
-        code = code.to(A.device)

     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
     if quant_state is None:
         quant_state = (absmax, code)
+    else:
+        absmax, code = quant_state


     if A.device.type != 'cpu':
+        device = pre_call(A.device)
+        code = code.to(A.device)
         if blocksize not in [2048, 4096, 1024, 512]:
             raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
         is_on_gpu([A, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+        post_call(A.device)
     else:
+        code = code.cpu()
         lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
-        post_call(A.device)

     return out
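For reference, a minimal round-trip sketch of the two functions changed above (assuming the test suite's import alias `F` for `bitsandbytes.functional`); after this commit the `code` lookup table follows the input tensor's device in both branches:

    import torch
    import bitsandbytes.functional as F

    # GPU path: 8-bit blockwise quantize, then dequantize.
    a = torch.rand(1024, 1024, device='cuda').half()
    qa, state = F.quantize_blockwise(a)        # state is (absmax, code)
    out = F.dequantize_blockwise(qa, state)    # float32 output by default

    # CPU path (the fixed branch): code is now moved via .cpu() before the
    # cquantize_blockwise_cpu_fp32 / cdequantize_blockwise_cpu_fp32 calls.
    b = torch.rand(4096)                       # float32 tensor on CPU
    qb, state_b = F.quantize_blockwise(b)
    out_b = F.dequantize_blockwise(qb, state_b)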

csrc/kernels.cu
Lines changed: 15 additions & 13 deletions

@@ -510,7 +510,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
 }

 template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
-__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n)
+__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
 {

   const int n_full = gridDim.x * BLOCK_SIZE;
@@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c

   __shared__ typename LoadChar::TempStorage loadchar;
   __shared__ typename StoreT::TempStorage storet;
-  __shared__ float smem_code[256];
+  //__shared__ float smem_code[256];
+  //float local_code[16];

-  if(threadIdx.x < 256)
-    smem_code[threadIdx.x] = code[threadIdx.x];
+  //if(threadIdx.x < 256)
+    //smem_code[threadIdx.x] = code[threadIdx.x];

   for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
   {
@@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
     __syncthreads();
     LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);

+    // load code through read-only cache via __ldg
     #pragma unroll NUM_PER_TH
     for(int j = 0; j < NUM_PER_TH; j++)
-      vals[j] = smem_code[qvals[j]]*local_abs_max;
+      vals[j] = __ldg(&code[qvals[j]])*local_abs_max;

     __syncthreads();
     StoreT(storet).Store(&(out[i]), vals, valid_items);
@@ -2798,14 +2800,14 @@ template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, flo
 template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);

-template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
-template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
-template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
-template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
+template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
+template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
+template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
+template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
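The kernel change above drops the 256-entry shared-memory code table and instead reads `code` through `__ldg`, which routes loads via the read-only data cache (compute capability 3.5+) and removes one staging `__syncthreads()`. A stripped-down sketch of the same technique, not the bitsandbytes kernel itself (`lut_lookup`, `table`, and `idx` are illustrative names):

    #include <cuda_runtime.h>

    // One value per thread: look up the quantized byte in a small table and
    // rescale by the per-block absmax. __ldg hints both loads through the
    // read-only cache, so the hot 256-entry table stays cached without
    // reserving shared memory.
    __global__ void lut_lookup(const float *table, const unsigned char *idx,
                               const float *absmax, float *out,
                               const int blocksize, const int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
            out[i] = __ldg(&table[idx[i]]) * __ldg(&absmax[i / blocksize]);
    }

The absmax indexing mirrors the blockwise scheme: each group of `blocksize` consecutive values shares one absolute-maximum scale.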

csrc/kernels.cuh
Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
 __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);

 template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n);
+template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);

 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,

tests/test_functional.py
Lines changed: 16 additions & 0 deletions

@@ -2166,3 +2166,19 @@ def test_kbit_quantile_estimation():
         val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
         err = torch.abs(val1-val2).mean()
         assert err < 0.035
+
+
+def test_bench_dequantization():
+    a = torch.rand(1024, 1024, device='cuda').half()
+    qa, SA = F.quantize_blockwise(a)
+
+    max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
+    #print(max_theoretical_mu)
+
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(100):
+        F.dequantize_blockwise(qa, SA, blocksize=2048)
+    torch.cuda.synchronize()
+    #print((time.time()-t0)/1e6)
+
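One reading of `max_theoretical_mu` in the benchmark above (an interpretation; the commit does not spell it out): it estimates the minimum time per call, in microseconds, if dequantization were purely bound by streaming the half-precision data at an assumed ~672 GB/s of memory bandwidth:

    # Hypothetical breakdown of the constant; 672 GB/s is an assumed peak
    # bandwidth, the commit does not name the target GPU.
    n_bytes = 1024 * 1024 * 2                      # 1M fp16 values, 2 bytes each
    gib = n_bytes / 1024**3                        # ~0.00195 GiB to stream
    max_theoretical_mu = gib / 672 * 1000 * 1000   # ~2.9 microseconds per call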
