Skip to content

Commit fdf4745

Browse files
int8: more tests passing, cleanup
1 parent 0ab14fe commit fdf4745

File tree

6 files changed

+93
-47
lines changed

6 files changed

+93
-47
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,14 @@ def forward(
308308
# 1. Quantize A
309309
if len(A.shape) == 3:
310310
A = A.reshape(-1, A.shape[-1])
311-
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
311+
312+
if ctx.needs_input_grad[1]:
313+
# Slower path
314+
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
315+
else:
316+
# Fast path
317+
CA, SCA, coo_tensorA = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
318+
CAt = SCAt = None
312319

313320
has_grad = False
314321

@@ -322,20 +329,24 @@ def forward(
322329
state.reset_grads()
323330

324331
# 2. Quantize B
325-
(
326-
state.CB,
327-
state.CBt,
328-
state.SCB,
329-
state.SCBt,
330-
_,
331-
) = F.double_quant(B.to(torch.float16))
332+
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
333+
334+
# (
335+
# state.CB,
336+
# state.CBt,
337+
# state.SCB,
338+
# state.SCBt,
339+
# _,
340+
# ) = F.double_quant(B.to(torch.float16))
332341

333342
if state.threshold > 0.0 and coo_tensorA is not None:
334343
state.idx = torch.unique(coo_tensorA._indices()[1]).long()
335344

336345
# Zero out the outliers in the int8 inputs
337346
CA[:, state.idx] = 0
338-
# CAt[:, state.idx] = 0
347+
348+
if CAt is not None:
349+
CAt[:, state.idx] = 0
339350

340351
# Extract the input outliers in original precision
341352
subA = A[:, state.idx]
@@ -372,7 +383,7 @@ def forward(
372383
ctx.tensors = (CAt, subA, A)
373384
ctx.tensor_states = (SCAt, state.idx)
374385
else:
375-
ctx.tensors = [None, None, None] # A]
386+
ctx.tensors = [None, None, None]
376387
ctx.tensor_states = (None, None)
377388
ctx.save_for_backward(None, None)
378389

@@ -403,17 +414,16 @@ def backward(ctx, grad_output):
403414
if len(grad_output.shape) == 3:
404415
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
405416

406-
Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
407-
# if req_gradB:
408-
409-
# grad_B = torch.matmul(grad_output.t(), A)
410-
# if state.threshold > 0.0 and subA is not None:
411-
# grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
412417
if req_gradB:
413-
gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
414-
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
418+
grad_B = torch.matmul(grad_output.t(), A)
415419
if state.threshold > 0.0 and subA is not None:
416420
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
421+
# Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
422+
# if req_gradB:
423+
# gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
424+
# grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
425+
# if state.threshold > 0.0 and subA is not None:
426+
# grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
417427

418428
if req_gradA:
419429
# grad_output @ B.T

bitsandbytes/functional.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2409,6 +2409,7 @@ def get_colrow_absmax(
24092409

24102410
if row_stats is None:
24112411
# shape [rows]; unsqueeze(-1) gives [rows,1]
2412+
# We have a CUDA kernel for row max, but not yet for cols.
24122413
row_stats = get_row_absmax(A, threshold)
24132414

24142415
if col_stats is None:
@@ -2521,29 +2522,42 @@ def extract_outliers_new(A: torch.Tensor, threshold: float):
25212522

25222523

25232524
def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
2525+
# TODO: Optimize/write CUDA kernel for this?
2526+
# Note: for inference, use the new int8_vectorwise_quant.
2527+
2528+
# Use CUDA kernel for rowwise and COO tensor
2529+
quant_row, row_stats, coo_tensor = int8_vectorwise_quant(A, threshold=threshold)
2530+
2531+
# PyTorch impl for colwise
2532+
_, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
2533+
if threshold > 0.0 and outlier_mask is not None:
2534+
A = A.masked_fill(outlier_mask, 0.0)
2535+
quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8)
2536+
2537+
if out_row is not None:
2538+
quant_row = out_row.copy_(quant_row)
2539+
if out_col is not None:
2540+
quant_col = out_col.copy_(quant_col)
2541+
2542+
return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor
2543+
2544+
2545+
def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
25242546
assert A.dtype == torch.half
2547+
is_on_gpu([A])
25252548

25262549
rows = prod(A.shape[:-1])
25272550
cols = A.shape[-1]
25282551

25292552
row_stats = torch.empty((rows,), device=A.device, dtype=torch.float32)
2530-
25312553
out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
25322554

25332555
if threshold > 0.0:
2534-
# Extract outliers to COO tensor:
2535-
# 1. Zero out all of the non-outliers, convert to COO.
2536-
# 2. Zero out the outliers in the dense tensor.
25372556
# TODO we could improve perf of this
2538-
# outlier_mask = A.abs() >= threshold
2539-
# coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()
2540-
# A = A.masked_fill(outlier_mask, 0.0)
25412557
coo_tensor = extract_outliers_new(A, threshold)
25422558
else:
25432559
coo_tensor = None
25442560

2545-
is_on_gpu([A, row_stats])
2546-
25472561
with torch.cuda.device_of(A):
25482562
lib.cint8_vector_quant(
25492563
get_ptr(A),
@@ -2554,9 +2568,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None,
25542568
ct.c_int32(cols),
25552569
)
25562570

2557-
# TODO: col_stats
2558-
2559-
return out_row, None, row_stats, None, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor
2571+
return out_row, row_stats, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor
25602572

25612573

25622574
def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):

csrc/kernels.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3612,7 +3612,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
36123612
#pragma unroll
36133613
for(int k = 0; k < num_values_8bit/4; k++)
36143614
{
3615-
#if __CUDA_ARCH__ >= 800
3615+
#if BNB_BF16_AVAILABLE
36163616
local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax;
36173617
local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax;
36183618
#else
@@ -3649,7 +3649,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
36493649
#pragma unroll
36503650
for(int k = 0; k < num_values_4bit/4; k++)
36513651
{
3652-
#if __CUDA_ARCH__ >= 800
3652+
#if BNB_BF16_AVAILABLE
36533653
local_C += (float)(local_A[k]*local_B[k]);
36543654
#else
36553655
// bf16 multipliation not supported

tests/test_autograd.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,16 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
253253
if not has_fp16_weights:
254254
if not transpose[0] and not transpose[1]:
255255
B2 = B2.t().contiguous()
256-
(
257-
state.CB,
258-
CBt,
259-
state.SCB,
260-
SCBt,
261-
coo_tensorB,
262-
) = bnb.functional.double_quant(B2.to(torch.float16))
256+
257+
state.CB, state.SCB, _ = bnb.functional.int8_vectorwise_quant(B2.to(torch.float16))
258+
259+
# (
260+
# state.CB,
261+
# CBt,
262+
# state.SCB,
263+
# SCBt,
264+
# coo_tensorB,
265+
# ) = bnb.functional.double_quant(B2.to(torch.float16))
263266
B2 = state.CB
264267

265268
if not transpose[0] and transpose[1]:

tests/test_functional.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,17 +1132,37 @@ def test_overflow():
11321132
c2 = torch.matmul(a.float(), b.float().t())
11331133

11341134

1135+
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
1136+
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
1137+
def test_coo_double_quant(dim1, dim2):
1138+
threshold = 2.00
1139+
for i in range(k):
1140+
A = torch.randn(dim1, dim2, device="cuda").half()
1141+
1142+
idx = torch.abs(A) >= threshold
1143+
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
1144+
1145+
if coo_tensor is not None:
1146+
A1 = A * idx
1147+
A2 = coo_tensor.to_dense()
1148+
torch.testing.assert_close(A1, A2)
1149+
1150+
A1 = A * (idx == 0)
1151+
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
1152+
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
1153+
1154+
11351155
# @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
11361156
# @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
11371157
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
11381158
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
1139-
def test_coo_double_quant(dim1, dim2):
1159+
def test_coo_int8_vectorwise_quant(dim1, dim2):
11401160
threshold = 3.00
11411161
for i in range(k):
11421162
A = torch.randn(dim1, dim2, device="cuda").half()
11431163

11441164
idx = torch.abs(A) >= threshold
1145-
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
1165+
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold)
11461166

11471167
if coo_tensor is not None:
11481168
A1 = A * idx
@@ -1239,13 +1259,13 @@ def test_integrated_sparse_decomp(dim1, dim2):
12391259
w1 = torch.randn(dim1, dim2).cuda().half()
12401260
out1 = torch.matmul(A, w1.t())
12411261

1242-
Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
1243-
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
1262+
Cw1, statsw1, coo_tensor = F.int8_vectorwise_quant(w1)
1263+
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A)
12441264

12451265
out1_32, Sout1_32 = F.igemmlt(CA, Cw1)
12461266
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
12471267

1248-
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
1268+
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold)
12491269

12501270
out1_32, Sout1_32 = F.igemmlt(CA, Cw1)
12511271
out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

tests/test_linear8bitlt.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,11 @@ def test_linear_no_igemmlt():
7272

7373
assert linear_custom.state.CB is not None
7474
assert not linear_custom.state.has_fp16_weights
75-
assert torch.allclose(fx_ref, fx_ours, atol=0.02)
76-
assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
7775

78-
# assert linear_custom.state.CxB is None
76+
idx = torch.isclose(fx_ref, fx_ours, atol=0.02, rtol=1e-5)
77+
assert (idx == 0).sum().item() < fx_ref.numel() * 2.5e-4
78+
torch.testing.assert_close(fx_ref, fx_ours, atol=0.03, rtol=1e-5)
79+
torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5)
7980

8081

8182
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))

0 commit comments

Comments
 (0)