Skip to content

Commit 521da0c

Browse files
int8 sparse decomp: small perf improvement
1 parent 32979b4 commit 521da0c

File tree

3 files changed

+16
-21
lines changed

3 files changed

+16
-21
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,10 @@ def forward(
319319

320320
if ctx.needs_input_grad[1]:
321321
# Slower path
322-
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
322+
CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
323323
else:
324324
# Fast path
325-
CA, SCA, coo_tensorA = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
325+
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
326326
CAt = SCAt = None
327327

328328
has_grad = False
@@ -339,8 +339,8 @@ def forward(
339339
# 2. Quantize B
340340
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
341341

342-
if state.threshold > 0.0 and coo_tensorA is not None:
343-
state.idx = torch.unique(coo_tensorA._indices()[1]).long()
342+
if state.threshold > 0.0 and outlier_cols is not None:
343+
state.idx = outlier_cols
344344

345345
# Zero out the outliers in the transposed 8bit inputs.
346346
if CAt is not None:

bitsandbytes/functional.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,7 +2546,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None,
25462546
# Note: for inference, use the new int8_vectorwise_quant.
25472547

25482548
# Use CUDA kernel for rowwise and COO tensor
2549-
quant_row, row_stats, coo_tensor = int8_vectorwise_quant(A, threshold=threshold)
2549+
quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold)
25502550

25512551
# PyTorch impl for colwise
25522552
_, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
@@ -2559,7 +2559,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None,
25592559
if out_col is not None:
25602560
quant_col = out_col.copy_(quant_col)
25612561

2562-
return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor
2562+
return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
25632563

25642564

25652565
def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
@@ -2574,13 +2574,9 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
25742574

25752575
if threshold > 0.0:
25762576
# TODO we could improve perf of this
2577-
2578-
# A.masked_fill(A.abs() < threshold, 0.0).to_sparse_coo()
2579-
# coo_tensor = extract_outliers_new(A, threshold)
2580-
coo_tensor = torch.masked_fill(A, A.abs() < threshold, 0.0).to_sparse_coo()
2581-
2577+
outlier_cols = torch.argwhere((A.abs() >= threshold).any(dim=0)).view(-1)
25822578
else:
2583-
coo_tensor = None
2579+
outlier_cols = None
25842580

25852581
with torch.cuda.device_of(A):
25862582
lib.cint8_vector_quant(
@@ -2593,7 +2589,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
25932589
get_tensor_stream(A),
25942590
)
25952591

2596-
return out_row, row_stats, coo_tensor
2592+
return out_row, row_stats, outlier_cols
25972593

25982594

25992595
@deprecated(

bitsandbytes/research/autograd/_functions.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,11 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
217217
# 1. Quantize A
218218
if len(A.shape) == 3:
219219
A = A.view(-1, A.shape[-1]).contiguous()
220-
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
220+
CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
221221

222-
if state.threshold > 0.0 and coo_tensorA is not None:
222+
if state.threshold > 0.0 and outlier_cols is not None:
223223
if state.has_fp16_weights:
224-
# idx = torch.unique(coo_tensorA.colidx).long()
225-
idx = torch.unique(coo_tensorA._indices()[1]).long()
224+
idx = outlier_cols
226225
CA[:, idx] = 0
227226
# CAt[:, idx] = 0
228227
subA = A[:, idx]
@@ -257,9 +256,9 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
257256
else:
258257
has_grad = False
259258

260-
if coo_tensorA is not None and not state.has_fp16_weights:
259+
if outlier_cols is not None and not state.has_fp16_weights:
261260
# extract outliers
262-
state.idx = torch.unique(coo_tensorA._indices()[1]).long()
261+
state.idx = outlier_cols
263262

264263
# outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
265264
outliers = state.CB[:, state.idx.long()].clone()
@@ -287,7 +286,7 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
287286
output = output.to(A.dtype).add_(bias)
288287

289288
# 4. Mixed-precision decomposition matmul
290-
if coo_tensorA is not None and subA is not None:
289+
if outlier_cols is not None and subA is not None:
291290
output += torch.matmul(subA, state.subB)
292291

293292
# 5. Save state
@@ -327,7 +326,7 @@ def backward(ctx, grad_output):
327326
if len(grad_output.shape) == 3:
328327
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
329328

330-
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
329+
Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.double_quant(grad_output.to(torch.float16))
331330

332331
if req_gradB:
333332
# print('back A shape', A.shape)

0 commit comments

Comments
 (0)