
Commit 61a4a20

use QuantState class for quant_state
1 parent e812136 commit 61a4a20

File tree

4 files changed: +97 -102 lines changed

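In short, this commit replaces the positional list that previously carried quantization state with a small QuantState class, so call sites read named attributes instead of unpacking by position. A rough before/after sketch (illustration only) of the access pattern that changes throughout the diff below:

# before: positional unpacking of a plain list
absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state

# after: named attributes on a QuantState object
absmax = quant_state.absmax
blocksize = quant_state.blocksize
if quant_state.nested:  # replaces the "compressed_stats is not None" check
    absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + quant_state.offset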

bitsandbytes/autograd/_functions.py

Lines changed: 8 additions & 10 deletions
@@ -496,15 +496,15 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, state=None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B
             ctx.bias = bias
-            B_shape = state[1]
+            B_shape = quant_state.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
@@ -513,10 +513,10 @@ def forward(ctx, A, B, out=None, bias=None, state=None):

         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

         # 3. Save state
-        ctx.state = state
+        ctx.state = quant_state
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

         if any(ctx.needs_input_grad[:2]):
@@ -534,7 +534,6 @@ def backward(ctx, grad_output):

         req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
         A, B = ctx.tensors
-        state = ctx.state

         grad_A, grad_B, grad_bias = None, None, None

@@ -563,15 +562,14 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
+def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
-        if A.shape[-1] % blocksize != 0:
-            warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
+        if A.shape[-1] % quant_state.blocksize != 0:
+            warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            out = F.gemv_4bit(A, B.t(), out, quant_state=quant_state)
             if bias is not None:
                 out += bias
             return out
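For context, a minimal usage sketch (not part of the commit) of the updated matmul_4bit call, roughly following the pattern used by bnb.nn.Linear4bit; the weight W and activation x are placeholders and a CUDA build is assumed:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")  # weight (out_features, in_features)
x = torch.randn(1, 1, 4096, dtype=torch.float16, device="cuda")  # single-token activation

W4, quant_state = F.quantize_4bit(W, quant_type="nf4")           # packed uint8 data plus a QuantState
out = bnb.matmul_4bit(x, W4.t(), quant_state=quant_state)        # single token, no grad: takes the gemv_4bit path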

bitsandbytes/functional.py

Lines changed: 85 additions & 69 deletions
@@ -566,6 +566,25 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n

     return out

+class QuantState:
+    def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
+        self.absmax = absmax
+        self.shape = shape
+        self.code = code
+        self.dtype = dtype
+        self.blocksize = blocksize
+        self.quant_type = quant_type
+        self.offset = offset
+        self.state2 = state2
+        self.nested = state2 is not None
+
+    def to(self, device):
+        # make sure the quantization state is on the right device
+        self.absmax = self.absmax.to(device)
+        if self.nested:
+            self.offset = self.offset.to(device)
+            self.state2.absmax = self.state2.absmax.to(device)
+            self.state2.code = self.state2.code.to(device)

 def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
@@ -633,16 +652,16 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
-        state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2]
+        quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2)
     else:
-        state = [absmax, code, blocksize, nested, A.dtype, None, None]
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)

-    return out, state
+    return out, quant_state


 def dequantize_blockwise(
     A: Tensor,
-    quant_state: Tuple[Tensor, Tensor] = None,
+    quant_state: QuantState = None,
     absmax: Tensor = None,
     code: Tensor = None,
     out: Tensor = None,
@@ -659,8 +678,8 @@ def dequantize_blockwise(
     ----------
     A : torch.Tensor
         The input 8-bit tensor.
-    quant_state : tuple(torch.Tensor, torch.Tensor)
-        Tuple of code and absmax values.
+    quant_state : QuantState
+        Object with code, absmax and other quantization state components.
     absmax : torch.Tensor
         The absmax values.
     code : torch.Tensor
@@ -681,36 +700,35 @@ def dequantize_blockwise(
         code = name2qmap["dynamic"]

     if quant_state is None:
-        quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
-
-    absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
-
-    if nested:
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)
+
+    absmax = quant_state.absmax
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset
     if absmax.dtype != torch.float32: absmax = absmax.float()

     if out is None:
-        out = torch.empty(A.shape, dtype=dtype, device=A.device)
+        out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)

     if A.device.type != 'cpu':
         device = pre_call(A.device)
-        code = code.to(A.device)
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
-            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
+        code = quant_state.code.to(A.device)
+        if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+            raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
         is_on_gpu([A, absmax, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.bfloat16:
-            lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         post_call(A.device)
     else:
-        code = code.cpu()
-        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
+        code = quant_state.code.cpu()
+        lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel()))

     return out

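For reference, a small round-trip sketch (not part of the commit) showing the new return type in use; a CUDA build is assumed:

import torch
import bitsandbytes.functional as F

x = torch.randn(4096, device="cuda")
x8, quant_state = F.quantize_blockwise(x, blocksize=4096)   # quant_state is a QuantState, not a list
x_hat = F.dequantize_blockwise(x8, quant_state)             # reads .absmax/.code/.blocksize internally
print((x - x_hat).abs().max())                              # small quantization error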
@@ -839,26 +857,26 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)

-    datatype = get_4bit_type(quant_type, device=A.device)
+    code = get_4bit_type(quant_type, device=A.device)

     if compress_statistics:
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
+        state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2)
     else:
-        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]
+        state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, )

     return out, state

-def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')

-def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')

-def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
+def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
     """
     Dequantizes FP4 blockwise quantized values.

@@ -868,8 +886,8 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
     ----------
     A : torch.Tensor
         The input 8-bit tensor (packed 4-bit values).
-    quant_state : tuple(torch.Tensor, torch.Size, torch.dtype)
-        Tuple of absmax values, original tensor shape and original dtype.
+    quant_state : QuantState
+        object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
     absmax : torch.Tensor
         The absmax values.
     out : torch.Tensor
@@ -892,41 +910,40 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:

     if quant_state is None:
         assert absmax is not None and out is not None
-        shape = out.shape
-        dtype = out.dtype
+
+        quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type)
+
     else:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
+        absmax = quant_state.absmax


-    if compressed_stats is not None:
-        offset, state2 = compressed_stats
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset
     if absmax.dtype != torch.float32: absmax = absmax.float()

     if out is None:
-        out = torch.empty(shape, dtype=dtype, device=A.device)
+        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)

     n = out.numel()

-
     device = pre_call(A.device)
     is_on_gpu([A, absmax, out])
     if out.dtype == torch.float32:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     elif out.dtype == torch.float16:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     elif out.dtype == torch.bfloat16:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
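Similarly, a hedged sketch (not from the commit) of the 4-bit path with compressed statistics, which exercises quant_state.nested, .state2 and .offset:

import torch
import bitsandbytes.functional as F

W = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
W4, qs = F.quantize_4bit(W, blocksize=64, compress_statistics=True, quant_type="nf4")
print(qs.nested, qs.shape, qs.blocksize)   # True torch.Size([1024, 1024]) 64
W_hat = F.dequantize_4bit(W4, qs)          # dequantizes the nested absmax first, then the 4-bit weights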
@@ -952,22 +969,22 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:

 def dequantize(
     A: Tensor,
-    quant_state: Tuple[Tensor, Tensor] = None,
+    state: Tuple[Tensor, Tensor] = None,
     absmax: Tensor = None,
     code: Tensor = None,
     out: Tensor = None,
 ) -> Tensor:
-    assert quant_state is not None or absmax is not None
-    if code is None and quant_state is None:
+    assert state is not None or absmax is not None
+    if code is None and state is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
         code = code.to(A.device)

-    if quant_state is None:
-        quant_state = (absmax, code)
-    out = dequantize_no_absmax(A, quant_state[1], out)
-    return out * quant_state[0]
+    if state is None:
+        state = (absmax, code)
+    out = dequantize_no_absmax(A, state[1], out)
+    return out * state[0]


 def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
@@ -1472,23 +1489,22 @@ def gemv_4bit(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    state=None
+    quant_state=None
 ):
     prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
-    if state is None:
+    if quant_state is None:
         raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')

     if A.numel() != A.shape[-1]:
         raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')

-    Bshape = state[1]
+    Bshape = quant_state.shape
     bout = Bshape[0]
-    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
-    if compressed_stats is not None:
-        offset, state2 = compressed_stats
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+    absmax = quant_state.absmax
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset

     if out is None:
         if len(A.shape) == 3:
@@ -1502,7 +1518,7 @@ def gemv_4bit(
         lda = Bshape[0]
         ldc = Bshape[0]
         ldb = (A.shape[-1]+1)//2
-        is_on_gpu([B, A, out, absmax, state[-1]])
+        is_on_gpu([B, A, out, absmax, quant_state.code])
         m = ct.c_int32(m)
         n = ct.c_int32(n)
         k = ct.c_int32(k)
@@ -1512,11 +1528,11 @@

         if B.dtype == torch.uint8:
             if A.dtype == torch.float16:
-                lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+                lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
             elif A.dtype == torch.bfloat16:
-                lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+                lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
             elif A.dtype == torch.float32:
-                lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+                lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
             else:
                 raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

@@ -1798,7 +1814,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):

 def mm_dequant(
     A,
-    quant_state,
+    state,
     row_stats,
     col_stats,
     out=None,
@@ -1808,7 +1824,7 @@ def mm_dequant(
 ):
     assert A.dtype == torch.int32
     if bias is not None: assert bias.dtype == torch.float16
-    out_shape = quant_state[0]
+    out_shape = state[0]
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
