Commit ad6eef9

Merge pull request #753 from poedator/save4

Save and load in NF4 / FP4 formats

2 parents e812136 + 851806e

File tree

6 files changed: +376 −95 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -133,3 +133,4 @@ dmypy.json
 
 dependencies
 cuda_build
+.vscode/*

bitsandbytes/autograd/_functions.py

Lines changed: 7 additions & 9 deletions
@@ -496,15 +496,15 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
 
     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, state=None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B
             ctx.bias = bias
-            B_shape = state[1]
+            B_shape = quant_state.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
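
This hunk replaces the positional quant_state list with an attribute-style F.QuantState object (state[1] becomes quant_state.shape). Below is a minimal, hypothetical sketch of such a container, inferred only from the seven fields the old tuple unpacking exposed in the matmul_4bit hunk further down; the actual QuantState class this PR adds to bitsandbytes.functional may carry more logic, e.g. for the NF4/FP4 save/load path.

# Hypothetical sketch -- NOT the actual bitsandbytes class. Field names come
# from the tuple the old code unpacked:
#   absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class QuantStateSketch:
    absmax: torch.Tensor                # per-block scaling statistics
    shape: torch.Size                   # shape of the original weight (replaces state[1])
    dtype: torch.dtype                  # dtype to dequantize back to
    blocksize: int                      # elements per quantization block
    compressed_stats: Optional[object]  # nested state when absmax is itself quantized
    quant_type: str                     # "nf4" or "fp4"
    data_type: torch.Tensor             # 4-bit code values (lookup table)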
@@ -513,10 +513,10 @@ def forward(ctx, A, B, out=None, bias=None, state=None):
 
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
 
         # 3. Save state
-        ctx.state = state
+        ctx.state = quant_state
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
         if any(ctx.needs_input_grad[:2]):
@@ -534,7 +534,6 @@ def backward(ctx, grad_output):
 
         req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
         A, B = ctx.tensors
-        state = ctx.state
 
         grad_A, grad_B, grad_bias = None, None, None
 
@@ -563,12 +562,11 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)
 
 
-def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
+def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
-        if A.shape[-1] % blocksize != 0:
-            warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
+        if A.shape[-1] % quant_state.blocksize != 0:
+            warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
     else:
         out = F.gemv_4bit(A, B.t(), out, state=quant_state)
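
With quant_state carrying named attributes, the single-token inference fast path reads blocksize directly off the object instead of unpacking a seven-element tuple. A usage sketch follows, mirroring how bnb.nn.Linear4bit invokes matmul_4bit; it assumes a CUDA device and a bitsandbytes build from around this commit, and defaults may have shifted since.

# Usage sketch under stated assumptions: CUDA available, bitsandbytes
# at roughly this commit.
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Quantize a weight matrix to NF4; quantize_4bit returns the packed
# 4-bit tensor plus the quant_state used throughout the diff above.
W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
W4, quant_state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")

# Single-token input: A.numel() == A.shape[-1] and no grad, so
# matmul_4bit takes the fused F.gemv_4bit fast path.
x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = bnb.matmul_4bit(x, W4.t(), quant_state=quant_state)

# The warning in the hunk above only fires when the hidden dimension is
# not a multiple of quant_state.blocksize; 4096 % 64 == 0, so it stays quiet.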
