Skip to content

Commit b104ce3

Browse files
authored
Merge branch 'main' into cleanup
2 parents 62c0bd2 + 08fa2e7 commit b104ce3

File tree

8 files changed

+361
-118
lines changed

8 files changed

+361
-118
lines changed

bitsandbytes/cextension.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,13 @@ def generate_instructions(self):
5252
self.add_log_entry('python setup.py install')
5353

5454
def initialize(self):
55-
self.cuda_setup_log = []
55+
self.has_printed = False
5656
self.lib = None
57+
self.run_cuda_setup()
58+
59+
def run_cuda_setup(self):
60+
self.initialized = True
61+
self.cuda_setup_log = []
5762

5863
from .cuda_setup.main import evaluate_cuda_setup
5964
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
@@ -89,7 +94,8 @@ def initialize(self):
8994
else:
9095
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
9196
self.lib = ct.cdll.LoadLibrary(binary_path)
92-
except:
97+
except Exception as ex:
98+
self.add_log_entry(str(ex))
9399
self.print_log_stack()
94100

95101
def add_log_entry(self, msg, is_warning=False):

bitsandbytes/functional.py

Lines changed: 114 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55
import ctypes as ct
6+
import itertools
67
import operator
78
import random
89
from functools import reduce # Required in Python 3
@@ -130,13 +131,59 @@ def get_instance(cls):
130131
return cls._instance
131132

132133

133-
def create_linear_map(signed=True):
134-
if signed:
135-
return torch.linspace(-1.0, 1.0, 256)
136-
return torch.linspace(0.0, 1.0, 256)
134+
def create_linear_map(signed=True, total_bits=8):
135+
sign = (-1.0 if signed else 0.0)
137136

138-
139-
def create_dynamic_map(signed=True, n=7):
137+
values = torch.linspace(sign, 1.0, 2**total_bits)
138+
gap = 256 - values.numel()
139+
if gap == 0:
140+
return values
141+
else:
142+
l = values.numel()//2
143+
#return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
144+
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
145+
146+
147+
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
148+
e = exponent_bits
149+
p = precision_bits
150+
has_sign = 1 if signed else 0
151+
assert e+p == total_bits-has_sign
152+
# the exponent is biased to 2^(e-1) -1 == 0
153+
evalues = []
154+
pvalues = []
155+
for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
156+
evalues.append(2**val)
157+
158+
159+
lst = list(itertools.product([0, 1], repeat=precision_bits))
160+
for bit_pattern in lst:
161+
value = 1
162+
for i, pval in enumerate(list(bit_pattern)):
163+
value += pval*(2**-(i+1))
164+
pvalues.append(value)
165+
166+
assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
167+
values = []
168+
for ev in evalues:
169+
for pv in pvalues:
170+
if signed:
171+
values.append(-ev*pv)
172+
values.append(ev*pv)
173+
if total_bits < 8:
174+
gap = 256 - len(values)
175+
for i in range(gap):
176+
values.append(0)
177+
values.sort()
178+
code = torch.Tensor(values)
179+
code /= code.max()
180+
code[127] = 0
181+
182+
return code
183+
184+
185+
186+
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
140187
"""
141188
Creates the dynamic quantiztion map.
142189
@@ -157,28 +204,32 @@ def create_dynamic_map(signed=True, n=7):
157204
# these are additional items that come from the case
158205
# where all the exponent bits are zero and no
159206
# indicator bit is present
160-
additional_items = 2 ** (7 - n) - 1
207+
non_sign_bits = total_bits - (1 if signed else 0)
208+
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
161209
if not signed:
162210
additional_items = 2 * additional_items
163-
for i in range(n):
164-
fraction_items = (
165-
2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
166-
)
211+
for i in range(max_exponent_bits):
212+
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
167213
boundaries = torch.linspace(0.1, 1, fraction_items)
168214
means = (boundaries[:-1] + boundaries[1:]) / 2.0
169-
data += ((10 ** (-(n - 1) + i)) * means).tolist()
215+
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
170216
if signed:
171-
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
217+
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
172218

173-
if additional_items > 0:
174-
boundaries = torch.linspace(0.1, 1, additional_items + 1)
175-
means = (boundaries[:-1] + boundaries[1:]) / 2.0
176-
data += ((10 ** (-(n - 1) + i)) * means).tolist()
177-
if signed:
178-
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
219+
if additional_items > 0:
220+
boundaries = torch.linspace(0.1, 1, additional_items + 1)
221+
means = (boundaries[:-1] + boundaries[1:]) / 2.0
222+
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
223+
if signed:
224+
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
179225

180226
data.append(0)
181227
data.append(1.0)
228+
229+
gap = 256 - len(data)
230+
for i in range(gap):
231+
data.append(0)
232+
182233
data.sort()
183234
return Tensor(data)
184235

@@ -322,9 +373,7 @@ def nvidia_transform(
322373
return out, new_state
323374

324375

325-
def estimate_quantiles(
326-
A: Tensor, out: Tensor = None, offset: float = 1 / 512
327-
) -> Tensor:
376+
def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
328377
'''
329378
Estimates 256 equidistant quantiles on the input tensor eCDF.
330379
@@ -344,25 +393,36 @@ def estimate_quantiles(
344393
out : torch.Tensor
345394
Tensor with the 256 estimated quantiles.
346395
offset : float
347-
The offset for the first and last quantile from 0 and 1. Default: 1/512
396+
The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
397+
num_quantiles : int
398+
The number of equally spaced quantiles.
348399
349400
Returns
350401
-------
351402
torch.Tensor:
352403
The 256 quantiles in float32 datatype.
353404
'''
405+
if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
406+
if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
407+
if num_quantiles < 256 and offset == 1/(512):
408+
# override default arguments
409+
offset = 1/(2*num_quantiles)
410+
354411
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
355412
is_on_gpu([A, out])
413+
device = pre_call(A.device)
356414
if A.dtype == torch.float32:
357-
lib.cestimate_quantiles_fp32(
358-
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
359-
)
415+
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
360416
elif A.dtype == torch.float16:
361-
lib.cestimate_quantiles_fp16(
362-
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
363-
)
417+
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
364418
else:
365419
raise NotImplementedError(f"Not supported data type {A.dtype}")
420+
post_call(device)
421+
422+
if num_quantiles < 256:
423+
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
424+
out = out[idx]
425+
366426
return out
367427

368428

@@ -395,15 +455,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
395455
The quantization state to undo the quantization.
396456
"""
397457

458+
398459
if code is None:
399460
if "dynamic" not in name2qmap:
400461
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
401462
code = name2qmap["dynamic"]
402-
code = code.to(A.device)
403463

404464
if absmax is None:
405465
n = A.numel()
406-
blocksize = (blocksize if A.device.type == 'cpu' else 4096)
407466
blocks = n // blocksize
408467
blocks += 1 if n % blocksize > 0 else 0
409468
absmax = torch.zeros((blocks,), device=A.device)
@@ -412,29 +471,33 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
412471
out = torch.zeros_like(A, dtype=torch.uint8)
413472

414473
if A.device.type != 'cpu':
415-
is_on_gpu([code, A, absmax, out, rand])
474+
assert blocksize in [4096, 2048, 1024, 512]
475+
cblocksize = ct.c_int32(blocksize)
476+
prev_device = pre_call(A.device)
477+
code = code.to(A.device)
416478
if rand is not None:
479+
is_on_gpu([code, A, out, absmax, rand])
480+
assert blocksize==4096
417481
assert rand.numel() >= 1024
418482
rand_offset = random.randint(0, 1023)
419483
if A.dtype == torch.float32:
420484
lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
421485
elif A.dtype == torch.float16:
422486
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
423487
else:
424-
raise ValueError(
425-
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
426-
)
488+
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
427489
else:
490+
is_on_gpu([code, A, out, absmax])
428491
if A.dtype == torch.float32:
429-
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
492+
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
430493
elif A.dtype == torch.float16:
431-
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
494+
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
432495
else:
433-
raise ValueError(
434-
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
435-
)
496+
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
497+
post_call(A.device)
436498
else:
437499
# cpu
500+
code = code.cpu()
438501
assert rand is None
439502
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
440503

@@ -479,27 +542,30 @@ def dequantize_blockwise(
479542
if "dynamic" not in name2qmap:
480543
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
481544
code = name2qmap["dynamic"]
482-
code = code.to(A.device)
483545

484546
if out is None:
485547
out = torch.zeros_like(A, dtype=torch.float32)
486548
if quant_state is None:
487549
quant_state = (absmax, code)
550+
else:
551+
absmax, code = quant_state
488552

489553

490554
if A.device.type != 'cpu':
491-
if blocksize not in [2048, 4096]:
492-
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
555+
device = pre_call(A.device)
556+
code = code.to(A.device)
557+
if blocksize not in [2048, 4096, 1024, 512]:
558+
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
493559
is_on_gpu([A, out])
494560
if out.dtype == torch.float32:
495-
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
561+
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
496562
elif out.dtype == torch.float16:
497-
lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
563+
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
498564
else:
499-
raise ValueError(
500-
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
501-
)
565+
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
566+
post_call(A.device)
502567
else:
568+
code = code.cpu()
503569
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
504570

505571
return out

csrc/kernels.cu

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
428428
}
429429

430430
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
431-
__launch_bounds__(TH, 4)
431+
//__launch_bounds__(TH, 4)
432432
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
433433
{
434434
const int n_full = gridDim.x * BLOCK_SIZE;
435435
int valid_items = 0;
436436
const int base_idx = (blockIdx.x * BLOCK_SIZE);
437437

438-
T vals[NUM];
439-
float rand_vals[NUM];
440-
unsigned char qvals[NUM];
438+
T vals[NUM_PER_TH];
439+
float rand_vals[NUM_PER_TH];
440+
unsigned char qvals[NUM_PER_TH];
441441
//float local_abs_max = -FLT_MAX;
442442
float local_abs_max = 0.0f;
443443
int local_rand_idx = 0;
@@ -510,26 +510,27 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
510510
}
511511

512512
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
513-
__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n)
513+
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
514514
{
515515

516516
const int n_full = gridDim.x * BLOCK_SIZE;
517517
int valid_items = 0;
518518
const int base_idx = (blockIdx.x * BLOCK_SIZE);
519519

520-
T vals[NUM];
521-
unsigned char qvals[NUM];
520+
T vals[NUM_PER_TH];
521+
unsigned char qvals[NUM_PER_TH];
522522
float local_abs_max = -FLT_MAX;
523523

524524
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
525525
typedef cub::BlockStore<T, THREADS, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
526526

527527
__shared__ typename LoadChar::TempStorage loadchar;
528528
__shared__ typename StoreT::TempStorage storet;
529-
__shared__ float smem_code[256];
529+
//__shared__ float smem_code[256];
530+
//float local_code[16];
530531

531-
if(threadIdx.x < 256)
532-
smem_code[threadIdx.x] = code[threadIdx.x];
532+
//if(threadIdx.x < 256)
533+
//smem_code[threadIdx.x] = code[threadIdx.x];
533534

534535
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
535536
{
@@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
539540
__syncthreads();
540541
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
541542

543+
// load code through read-only cache via __ldg
542544
#pragma unroll NUM_PER_TH
543545
for(int j = 0; j < NUM_PER_TH; j++)
544-
vals[j] = smem_code[qvals[j]]*local_abs_max;
546+
vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
545547

546548
__syncthreads();
547549
StoreT(storet).Store(&(out[i]), vals, valid_items);
@@ -2791,11 +2793,21 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half
27912793
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
27922794
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
27932795
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2794-
2795-
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
2796-
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
2797-
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
2798-
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
2796+
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2797+
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2798+
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2799+
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2800+
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2801+
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2802+
2803+
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
2804+
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
2805+
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
2806+
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
2807+
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
2808+
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
2809+
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
2810+
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
27992811

28002812

28012813

0 commit comments

Comments
 (0)