
Commit 4c11d6d

reverted fn signatures in functional()
1 parent 1d9f0f2 commit 4c11d6d

3 files changed: 15 additions, 15 deletions
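
This revert restores the original keyword names: F.gemv_4bit takes its quantization state as state= again (matmul_4bit and the tests are updated to match), while mm_dequant goes back to quant_state. A minimal before/after sketch of a gemv_4bit call site, using the variable names from the matmul_4bit change below:

# before this commit (the renamed signature):
#   out = F.gemv_4bit(A, B.t(), out, quant_state=quant_state)
# after the revert, the keyword is state= again:
out = F.gemv_4bit(A, B.t(), out, state=quant_state)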

bitsandbytes/autograd/_functions.py

Lines changed: 1 addition & 1 deletion
@@ -569,7 +569,7 @@ def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = N
             warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, quant_state=quant_state)
+            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
             if bias is not None:
                 out += bias
             return out

bitsandbytes/functional.py

Lines changed: 13 additions & 13 deletions
@@ -1579,22 +1579,22 @@ def gemv_4bit(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    quant_state=None
+    state=None
 ):
     prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
-    if quant_state is None:
+    if state is None:
         raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')

     if A.numel() != A.shape[-1]:
         raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')

-    Bshape = quant_state.shape
+    Bshape = state.shape
     bout = Bshape[0]
-    absmax = quant_state.absmax
-    if quant_state.nested:
-        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
-        absmax += quant_state.offset
+    absmax = state.absmax
+    if state.nested:
+        absmax = dequantize_blockwise(state.absmax, state.state2)
+        absmax += state.offset

     if out is None:
         if len(A.shape) == 3:
@@ -1608,7 +1608,7 @@ def gemv_4bit(
     lda = Bshape[0]
     ldc = Bshape[0]
     ldb = (A.shape[-1]+1)//2
-    is_on_gpu([B, A, out, absmax, quant_state.code])
+    is_on_gpu([B, A, out, absmax, state.code])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
@@ -1618,11 +1618,11 @@ def gemv_4bit(

     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.float32:
-            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

@@ -1904,7 +1904,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):

 def mm_dequant(
     A,
-    state,
+    quant_state,
     row_stats,
     col_stats,
     out=None,
@@ -1914,7 +1914,7 @@ def mm_dequant(
 ):
     assert A.dtype == torch.int32
     if bias is not None: assert bias.dtype == torch.float16
-    out_shape = state[0]
+    out_shape = quant_state[0]
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])

tests/test_functional.py

Lines changed: 1 addition & 1 deletion
@@ -2401,7 +2401,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):

     qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
     C3 = torch.matmul(A, B.t())
-    C2 = F.gemv_4bit(A, qB.t(), quant_state=state)
+    C2 = F.gemv_4bit(A, qB.t(), state=state)
     A.requires_grad = True
     C1 = bnb.matmul_4bit(A, qB.t(), state)
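
For reference, a minimal end-to-end sketch of exercising the reverted signature, modelled on the test above; the shapes, dtype, and quant_type='nf4' are illustrative assumptions, not taken from the diff:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# gemv_4bit expects A to be a vector, e.g. [1, 1, hidden]; the hidden size here is an assumed example value
A = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')
B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')

# quantize the weights to 4-bit; returns the packed tensor and its quantization state
qB, state = F.quantize_4bit(B, quant_type='nf4', compress_statistics=True)

# after this commit the state is passed via the state= keyword again
C = F.gemv_4bit(A, qB.t(), state=state)

# matmul_4bit routes inference-shaped inputs to the same kernel
C_ref = bnb.matmul_4bit(A, qB.t(), state)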
