Skip to content

Commit b1c4adc

Browse files
Type annotations, cleanup
1 parent b5d6135 commit b1c4adc

File tree

5 files changed

+31
-49
lines changed

5 files changed

+31
-49
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -244,25 +244,26 @@ def get_tile_inds(format, device):
244244
@dataclass
245245
class MatmulLtState:
246246
_tile_indices: Optional[torch.Tensor] = None
247+
247248
force_no_igemmlt: bool = False
248-
CB = None
249-
CxB = None # TODO: Deprecate/remove
250-
SB = None
251-
SCB = None
252249

253-
CxBt = None # TODO: Deprecate/remove
254-
SBt = None
255-
CBt = None
250+
CB: Optional[torch.Tensor] = None
251+
CxB: Optional[torch.Tensor] = None # TODO: Deprecate/remove
252+
SB: Optional[torch.Tensor] = None
253+
SCB: Optional[torch.Tensor] = None
254+
255+
CxBt: Optional[torch.Tensor] = None # TODO: Deprecate/remove
256+
SBt: Optional[torch.Tensor] = None
257+
CBt: Optional[torch.Tensor] = None
256258

257-
subB = None
259+
subB: Optional[torch.Tensor] = None
258260

259-
outlier_pool = None
261+
outlier_pool: Optional[GlobalOutlierPooler] = None
260262
has_accumulated_gradients = False
261263
threshold = 0.0
262-
idx = None
264+
idx: Optional[torch.Tensor] = None
263265
is_training = True
264266
has_fp16_weights = True
265-
memory_efficient_backward = False
266267
use_pool = False
267268
formatB = "row" # TODO: Deprecate/remove
268269

@@ -313,10 +314,10 @@ def forward(
313314
if A.dtype != torch.float16:
314315
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
315316

316-
# 1. Quantize A. Note that as a side-effect, outliers are suppressed.
317317
if len(A.shape) == 3:
318318
A = A.reshape(-1, A.shape[-1])
319319

320+
# 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt.
320321
if ctx.needs_input_grad[1]:
321322
# Slower path
322323
CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
@@ -366,6 +367,8 @@ def forward(
366367

367368
# 3. Int8 Matmul
368369
out32 = F.int8_linear_matmul(CA, state.CB)
370+
371+
# Dequantize matmul result
369372
if bias is None or bias.dtype == torch.float16:
370373
# we apply the fused bias here
371374
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
@@ -375,7 +378,7 @@ def forward(
375378

376379
# 4. Mixed-precision decomposition matmul
377380
if subA is not None and state.subB is not None:
378-
output += torch.matmul(subA, state.subB.to(subA.dtype))
381+
output += torch.matmul(subA, state.subB)
379382

380383
# 5. Save state
381384
ctx.state = state
@@ -399,15 +402,15 @@ def forward(
399402
return output
400403

401404
@staticmethod
402-
def backward(ctx, grad_output):
405+
def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
403406
if ctx.is_empty:
404407
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
405408
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
406409

407410
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
408411
CAt, subA, A = ctx.tensors
409412
SCAt, idx = ctx.tensor_states
410-
state = ctx.state
413+
state: MatmulLtState = ctx.state
411414
grad_A = grad_B = grad_bias = None
412415

413416
if req_gradBias:
@@ -499,7 +502,7 @@ def matmul(
499502
out: Optional[torch.Tensor] = None,
500503
state: Optional[MatmulLtState] = None,
501504
threshold=0.0,
502-
bias=None,
505+
bias: Optional[torch.Tensor] = None,
503506
):
504507
state = state or MatmulLtState()
505508
if threshold > 0.0:
@@ -512,7 +515,7 @@ def matmul_4bit(
512515
B: torch.Tensor,
513516
quant_state: F.QuantState,
514517
out: Optional[torch.Tensor] = None,
515-
bias=None,
518+
bias: Optional[torch.Tensor] = None,
516519
):
517520
assert quant_state is not None
518521

bitsandbytes/cextension.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,3 @@
1-
"""
2-
extract factors the build is dependent on:
3-
[X] compute capability
4-
[ ] TODO: Q - What if we have multiple GPUs of different makes?
5-
- CUDA version
6-
- Software:
7-
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
8-
- CuBLAS-LT: full-build 8-bit optimizer
9-
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
10-
11-
evaluation:
12-
- if paths faulty, return meaningful error
13-
- else:
14-
- determine CUDA version
15-
- determine capabilities
16-
- based on that set the default path
17-
"""
18-
191
import ctypes as ct
202
import logging
213
import os

bitsandbytes/functional.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,7 +2279,9 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
22792279
ldb = shapeB[-1] # Activations (batch, tokens, inputs)
22802280
ldc = shapeC[-1] # Output (batch, tokens, outputs)
22812281

2282-
assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
2282+
assert (
2283+
lda == ldb
2284+
), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
22832285

22842286
is_on_gpu([A, B, out])
22852287

@@ -2288,7 +2290,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
22882290
ptrA = get_ptr(A)
22892291
ptrB = get_ptr(B)
22902292
ptrC = get_ptr(out)
2291-
ptrRowScale = get_ptr(None)
2293+
ptrRowScale = None
22922294
m = ct.c_int32(m)
22932295
n = ct.c_int32(n)
22942296
k = ct.c_int32(k)
@@ -2303,7 +2305,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
23032305
has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
23042306

23052307
if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
2306-
raise NotImplementedError("igemmlt not implemented!")
2308+
raise NotImplementedError("int8_linear_matmul not implemented!")
23072309

23082310
if has_error:
23092311
raise RuntimeError(
@@ -2369,7 +2371,7 @@ def get_colrow_absmax(
23692371
col_stats: Optional[torch.Tensor] = None,
23702372
nnz_block_ptr: Optional[torch.Tensor] = None,
23712373
threshold=0.0,
2372-
):
2374+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
23732375
# Note: prior impl only works with fp16
23742376
assert A.is_floating_point()
23752377

@@ -2395,7 +2397,7 @@ def get_colrow_absmax(
23952397
return row_stats, col_stats, outlier_mask
23962398

23972399

2398-
def get_row_absmax(A, threshold=0.0):
2400+
def get_row_absmax(A: torch.Tensor, threshold=0.0):
23992401
assert A.dtype == torch.float16
24002402

24012403
rows = prod(A.shape[:-1])

bitsandbytes/nn/modules.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -566,11 +566,11 @@ def __init__(
566566
class Int8Params(torch.nn.Parameter):
567567
def __new__(
568568
cls,
569-
data=None,
569+
data: Optional[torch.Tensor] = None,
570570
requires_grad=True,
571571
has_fp16_weights=False,
572-
CB=None,
573-
SCB=None,
572+
CB: Optional[torch.Tensor] = None,
573+
SCB: Optional[torch.Tensor] = None,
574574
):
575575
if data is None:
576576
data = torch.empty(0)
@@ -881,7 +881,6 @@ def __init__(
881881
output_features: int,
882882
bias=True,
883883
has_fp16_weights=True,
884-
memory_efficient_backward=False,
885884
threshold=0.0,
886885
index=None,
887886
device=None,
@@ -898,13 +897,12 @@ def __init__(
898897
Whether the linear class uses the bias term as well.
899898
"""
900899
super().__init__(input_features, output_features, bias, device)
901-
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
902900
self.state = bnb.MatmulLtState()
903901
self.index = index
904902

905903
self.state.threshold = threshold
906904
self.state.has_fp16_weights = has_fp16_weights
907-
self.state.memory_efficient_backward = memory_efficient_backward
905+
908906
if threshold > 0.0 and not has_fp16_weights:
909907
self.state.use_pool = True
910908

bitsandbytes/research/autograd/_functions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,9 +328,6 @@ def backward(ctx, grad_output):
328328
grad_B = torch.matmul(grad_output.t(), A)
329329

330330
if req_gradA:
331-
# if state.CBt is not None:
332-
# gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
333-
# grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
334331
if state.CB is not None:
335332
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
336333
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)

0 commit comments

Comments
 (0)