
Commit 6747525

Added documentation for NF4; failing 8-bit matmul; fixed absmax bug. #529 #543
1 parent: 8a20cd8

4 files changed, +29 -2 lines

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -267,3 +267,8 @@ Features:
 
 Bug fixes:
 - Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. # 553
+- Fixed a missing scipy dependency in requirements.txt. #544
+
+Documentation:
+- Improved documentation for GPUs that do not support 8-bit matmul. #529
+- Added description and pointers for the NF4 data type. #543

bitsandbytes/cuda_setup/main.py

Lines changed: 2 additions & 1 deletion
@@ -163,7 +163,8 @@ def is_cublasLt_compatible(cc):
     if cc is not None:
         cc_major, cc_minor = cc.split('.')
         if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
-            CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True)
+            CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! \
+If you run into issues with 8-bit matmul, you can try 4-bit quantization: https://huggingface.co/blog/4bit-transformers-bitsandbytes", is_warning=True)
         else:
             has_cublaslt = True
     return has_cublaslt
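
For context, the threshold behind this warning can be checked ahead of time from user code. Below is a minimal sketch, not part of this commit, that mirrors the compute-capability test above using torch.cuda.get_device_capability; the helper name supports_fast_int8_matmul is made up for illustration.

import torch

def supports_fast_int8_matmul(device_index: int = 0) -> bool:
    # Hypothetical helper: mirrors the is_cublasLt_compatible threshold (compute capability >= 7.5).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device_index)
    return (major, minor) >= (7, 5)

if not supports_fast_int8_matmul():
    print("Only slow 8-bit matmul is available on this GPU; 4-bit quantization may work better: "
          "https://huggingface.co/blog/4bit-transformers-bitsandbytes")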

bitsandbytes/functional.py

Lines changed: 12 additions & 1 deletion
@@ -718,6 +718,16 @@ def get_4bit_type(typename, device=None, blocksize=64):
     if device is None: device = 'cuda'
     data = None
     if typename == 'nf4':
+        ''' Implements the NF4 data type.
+
+            Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
+            is normalized into the range [-1, 1].
+
+            For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314)
+
+            Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
+            the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
+        '''
         data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635,
                 -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
                 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
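
The new docstring summarizes the NF4 construction: split N(0, 1) into bins of equal probability mass, take one representative value per bin, and rescale into [-1, 1]. The sketch below only illustrates that idea; it is not the create_normal_map implementation the docstring points to, and it will not reproduce the exact table above (which, among other things, reserves an exact 0.0 code).

# Simplified illustration only; not bitsandbytes' create_normal_map.
import numpy as np
from scipy.stats import norm

def equal_area_levels(bits: int = 4) -> np.ndarray:
    n = 2 ** bits                          # 16 levels for a 4-bit code
    probs = (np.arange(n) + 0.5) / n       # midpoints of n equal-probability bins of N(0, 1)
    levels = norm.ppf(probs)               # standard-normal quantiles at those midpoints
    return levels / np.abs(levels).max()   # rescale into [-1, 1]

print(equal_area_levels())
# Unlike the hard-coded NF4 table above, this naive symmetric version has no exact 0.0 code.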
@@ -731,6 +741,7 @@ def get_4bit_type(typename, device=None, blocksize=64):
         # 0b101 = 6
         # 0b110 = 2
         # 0b111 = 3
+        # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
         data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
     elif typename == 'int4':
         data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
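
These tables are the raw code books behind get_4bit_type. A hedged usage sketch, assuming a CUDA-enabled bitsandbytes install; the returned tensors may be normalized copies of the lists above rather than the raw values.

import bitsandbytes.functional as F

nf4_codebook = F.get_4bit_type('nf4', device='cuda')   # 16 NF4 levels in [-1, 1]
fp4_codebook = F.get_4bit_type('fp4', device='cuda')   # 16 FP4 levels (sign + 2 exponent + 1 mantissa bit)
print(nf4_codebook)
print(fp4_codebook)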
@@ -888,10 +899,10 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
 
 
     if compressed_stats is not None:
-        if absmax.dtype != torch.float32: absmax = absmax.float()
         offset, state2 = compressed_stats
         absmax = dequantize_blockwise(absmax, state2)
         absmax += offset
+        if absmax.dtype != torch.float32: absmax = absmax.float()
 
     if out is None:
         out = torch.empty(shape, dtype=dtype, device=A.device)
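
The moved line is the absmax fix named in the commit message: with compressed statistics, absmax is stored block-quantized, so the cast to torch.float32 now happens after dequantize_blockwise and the offset are applied, guaranteeing a float32 absmax regardless of the dtype the decompression produced. Below is a hedged round-trip sketch that exercises this path, assuming a CUDA-enabled bitsandbytes install and a quantize_4bit counterpart with a compress_statistics flag; keyword names beyond those visible in this diff are assumptions.

import torch
import bitsandbytes.functional as F

weights = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)

# compress_statistics=True stores absmax block-quantized (QLoRA's "double quantization"),
# which is exactly the compressed_stats branch patched above.
packed, quant_state = F.quantize_4bit(weights, blocksize=64,
                                      compress_statistics=True, quant_type='nf4')
restored = F.dequantize_4bit(packed, quant_state)

print((weights - restored).abs().mean())   # small, nonzero reconstruction error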

bitsandbytes/nn/modules.py

Lines changed: 10 additions & 0 deletions
@@ -229,6 +229,16 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
 
 class LinearNF4(Linear4bit):
+    ''' Implements the NF4 data type.
+
+        Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
+        is normalized into the range [-1, 1].
+
+        For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314)
+
+        Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
+        the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
+    '''
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
 
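Since LinearNF4 only changes the quant_type handed to Linear4bit, it is meant to be used like a regular linear layer. A hedged usage sketch, assuming a CUDA GPU and a working bitsandbytes install; the sizes and compute_dtype are illustrative.

import torch
import bitsandbytes as bnb

layer = bnb.nn.LinearNF4(1024, 4096, bias=True, compute_dtype=torch.float16)
layer = layer.cuda()                       # weights are quantized to NF4 when moved to the GPU

x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
with torch.no_grad():
    y = layer(x)
print(y.shape)                             # torch.Size([8, 4096])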