Skip to content

Commit eed9c3c

Browse files
Improve docs and tests
1 parent f61d8bc commit eed9c3c

File tree

5 files changed

+88
-29
lines changed

5 files changed

+88
-29
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def forward(
320320
# 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt.
321321
if ctx.needs_input_grad[1]:
322322
# Slower path
323-
CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
323+
CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
324324
else:
325325
# Fast path
326326
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
@@ -422,7 +422,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
422422
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
423423

424424
if req_gradB:
425-
Cgrad, _, _, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
425+
Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
426426

427427
gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
428428
grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)

bitsandbytes/functional.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
442442
An input tensor may also be marked as `paged`, in which case the device placement is ignored.
443443
444444
Args:
445-
tensors (Iterable[Optional[torch.Tensor]]): A list of tensors to verify.
445+
tensors (`Iterable[Optional[torch.Tensor]]`): A list of tensors to verify.
446446
447447
Raises:
448448
`RuntimeError`: Raised when the verification fails.
@@ -2572,13 +2572,80 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
25722572
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
25732573

25742574

2575+
@deprecated("This function is deprecated. Please use `int8_double_quant` instead.", category=FutureWarning)
25752576
def double_quant(
25762577
A: torch.Tensor,
25772578
col_stats: Optional[torch.Tensor] = None,
25782579
row_stats: Optional[torch.Tensor] = None,
25792580
out_col: Optional[torch.Tensor] = None,
25802581
out_row: Optional[torch.Tensor] = None,
25812582
threshold=0.0,
2583+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[COOSparseTensor]]:
2584+
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
2585+
2586+
The statistics are determined both row-wise and column-wise (transposed).
2587+
2588+
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
2589+
2590+
<Tip warning={true}>
2591+
This function exists for backwards compatibility only. It is advised to use [`int8_double_quant`] instead.
2592+
The difference is that this function will return a [`COOSparseTensor`] for outliers instead of a column index.
2593+
</Tip>
2594+
2595+
Args:
2596+
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
2597+
col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales.
2598+
row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales.
2599+
out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data.
2600+
out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data.
2601+
threshold (`float`, *optional*):
2602+
An optional threshold for sparse decomposition of outlier features.
2603+
2604+
No outliers are held back when 0.0. Defaults to 0.0.
2605+
2606+
Returns:
2607+
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
2608+
- `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data.
2609+
- `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data.
2610+
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales.
2611+
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
2612+
- `COOSparseTensor`, *optional*: A structure representing the outlier values from the input tensor.
2613+
"""
2614+
2615+
coo_tensor = None
2616+
quant_row, quant_col, row_stats, col_stats, _ = int8_double_quant(
2617+
A,
2618+
col_stats,
2619+
row_stats,
2620+
out_col,
2621+
out_row,
2622+
threshold=threshold,
2623+
)
2624+
2625+
if threshold > 0.0:
2626+
# Build COO tensor for any outliers.
2627+
outlier_mask = A.abs() >= threshold
2628+
outlier_locations = outlier_mask.nonzero()
2629+
outliers = A[outlier_mask]
2630+
coo_tensor = COOSparseTensor(
2631+
A.shape[0],
2632+
A.shape[1],
2633+
outliers.numel(),
2634+
outlier_locations[:, 0].int(),
2635+
outlier_locations[:, 1].int(),
2636+
outliers,
2637+
)
2638+
2639+
return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor
2640+
2641+
2642+
def int8_double_quant(
2643+
A: torch.Tensor,
2644+
col_stats: Optional[torch.Tensor] = None,
2645+
row_stats: Optional[torch.Tensor] = None,
2646+
out_col: Optional[torch.Tensor] = None,
2647+
out_row: Optional[torch.Tensor] = None,
2648+
threshold=0.0,
25822649
):
25832650
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
25842651
@@ -2612,7 +2679,6 @@ def double_quant(
26122679
"""
26132680

26142681
# TODO: Optimize/write CUDA kernel for this?
2615-
# Note: for inference, use the new int8_vectorwise_quant.
26162682

26172683
# Use CUDA kernel for rowwise and COO tensor
26182684
quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold)
@@ -2665,8 +2731,6 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
26652731
# TODO we could improve perf of this
26662732
outliers = A.abs() >= threshold
26672733

2668-
# argwhere needs host/device sync, so we skip when
2669-
# there aren't actually any outliers.
26702734
if outliers.any():
26712735
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
26722736

bitsandbytes/research/autograd/_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non
215215
# 1. Quantize A
216216
if len(A.shape) == 3:
217217
A = A.view(-1, A.shape[-1]).contiguous()
218-
CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
218+
CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
219219

220220
if state.threshold > 0.0 and outlier_cols is not None:
221221
if state.has_fp16_weights:
@@ -248,7 +248,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non
248248
state.SCB,
249249
state.SCBt,
250250
_,
251-
) = F.double_quant(B.to(torch.float16))
251+
) = F.int8_double_quant(B.to(torch.float16))
252252
state.SB = (state.CB.shape, "row")
253253
else:
254254
has_grad = False
@@ -320,7 +320,7 @@ def backward(ctx, grad_output):
320320
if len(grad_output.shape) == 3:
321321
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
322322

323-
Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.double_quant(grad_output.to(torch.float16))
323+
Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.int8_double_quant(grad_output.to(torch.float16))
324324

325325
if req_gradB:
326326
# print('back A shape', A.shape)

tests/test_functional.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -606,8 +606,8 @@ def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims):
606606

607607
A = A.view(-1, A.shape[-1])
608608

609-
CA, _, statsA, _, _ = F.double_quant(A)
610-
CB, _, statsB, _, _ = F.int8_vectorwise_quant(B)
609+
CA, _, statsA, _, _ = F.int8_double_quant(A)
610+
CB, statsB, _ = F.int8_vectorwise_quant(B)
611611
output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)
612612

613613
torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
@@ -863,7 +863,7 @@ def test_double_quant(dim1, dim2):
863863
out_col1, Scol = F.vectorwise_quant(A, dim=0)
864864
out_row1, Srow = F.vectorwise_quant(A, dim=1)
865865

866-
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
866+
CA, CAt, statsA, statsAt, coo_tensor = F.int8_double_quant(A)
867867

868868
# max difference is 1 due to rounding differences
869869
torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
@@ -953,7 +953,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
953953

954954
out1 = torch.matmul(A.half(), B.t().half())
955955

956-
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
956+
C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
957957
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
958958
A2, SA = F.nvidia_transform(C1a, "col32")
959959
B2, SB = F.nvidia_transform(CB, formatB)
@@ -1032,7 +1032,7 @@ def test_row_scale_bench(dim1, dim4, inner):
10321032
torch.cuda.synchronize()
10331033
print("16", time.time() - t0)
10341034

1035-
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
1035+
C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
10361036
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
10371037
A2, SA = F.nvidia_transform(C1a, "col32")
10381038
B2, SB = F.nvidia_transform(CB, formatB)
@@ -1047,7 +1047,7 @@ def test_row_scale_bench(dim1, dim4, inner):
10471047
torch.cuda.synchronize()
10481048
print("row-wise", time.time() - t0)
10491049

1050-
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
1050+
C2a, C2b, stats2a, stats2b, coo_tensor = F.int8_double_quant(B)
10511051
B2, SB = F.nvidia_transform(C2a, formatB)
10521052
torch.cuda.synchronize()
10531053
t0 = time.time()
@@ -1115,7 +1115,8 @@ def test_coo_double_quant(dim1, dim2):
11151115

11161116
if coo_tensor is not None:
11171117
A1 = A * idx
1118-
A2 = coo_tensor.to_dense()
1118+
A2 = torch.zeros_like(A)
1119+
A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
11191120
torch.testing.assert_close(A1, A2)
11201121

11211122
A1 = A * (idx == 0)
@@ -1133,14 +1134,9 @@ def test_coo_int8_vectorwise_quant(dim1, dim2):
11331134
A = torch.randn(dim1, dim2, device="cuda").half()
11341135

11351136
idx = torch.abs(A) >= threshold
1136-
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold)
1137+
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
11371138

1138-
if coo_tensor is not None:
1139-
A1 = A * idx
1140-
A2 = coo_tensor.to_dense()
1141-
torch.testing.assert_close(A1, A2)
1142-
1143-
A1 = A * (idx == 0)
1139+
if outlier_cols is not None:
11441140
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
11451141
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
11461142

@@ -1230,13 +1226,14 @@ def test_integrated_sparse_decomp(dim1, dim2):
12301226
w1 = torch.randn(dim1, dim2).cuda().half()
12311227
out1 = torch.matmul(A, w1.t())
12321228

1233-
Cw1, statsw1, coo_tensor = F.int8_vectorwise_quant(w1)
1234-
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A)
1229+
Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
1230+
CA, statsA, _ = F.int8_vectorwise_quant(A)
12351231

12361232
out1_32 = F.int8_linear_matmul(CA, Cw1)
12371233
out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)
12381234

1239-
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold)
1235+
# CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
1236+
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
12401237

12411238
out1_32 = F.int8_linear_matmul(CA, Cw1)
12421239
out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)
@@ -1377,7 +1374,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
13771374
torch.nn.init.xavier_uniform_(B)
13781375
Bt = B.t().contiguous()
13791376

1380-
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
1377+
CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B)
13811378

13821379
rowidx = torch.randint(0, A.shape[-1], size=(15,))
13831380

tests/test_modules.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,15 +356,13 @@ def test_linear8bitlt_accumulated_gradient():
356356

357357

358358
@pytest.mark.parametrize("threshold", [0.0, 2.0])
359-
@pytest.mark.parametrize("memory_efficient_backward", [False])
360359
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
361360
l1 = (
362361
bnb.nn.Linear8bitLt(
363362
32,
364363
64,
365364
threshold=threshold,
366365
has_fp16_weights=False,
367-
memory_efficient_backward=memory_efficient_backward,
368366
)
369367
.cuda()
370368
.half()

0 commit comments

Comments
 (0)