
Commit 50fe50e

New naive mm_dequant kernel for row-major; cleanup
1 parent 0f2dc34 commit 50fe50e

9 files changed: +189 −276 lines


bitsandbytes/autograd/_functions.py

Lines changed: 5 additions & 23 deletions
@@ -284,7 +284,9 @@ def tile_indices(self):
 
 class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
+    def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
+        state = state or MatmulLtState()
+
         using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
@@ -417,8 +419,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
         ctx.tensor_states = (None, None)
         ctx.save_for_backward(None, None)
 
-        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
-        return clone_func(output.view(output_shape))
+        return output.reshape(output_shape)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -442,37 +443,18 @@ def backward(ctx, grad_output):
 
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
-            # CxAt, SAt = F.transform(CAt, formatB, transpose=True)
-            # C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
-            # gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
-            # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
-            gradB32, SgradB32 = F.igemmlt(
-                Cgradt.t(), CAt.t()
-            )  # issue here in test_linear_serialization w/ has fp16 weights
+            gradB32, SgradB32 = F.igemmlt(Cgradt.t(), CAt.t())
             grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
             if state.threshold > 0.0 and subA is not None:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
             if state.CBt is not None:
-                # C32grad, Sgrad = F.transform(Cgrad, "col32")
-                # if state.CxBt is None:
-                #     state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
-                # gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-                # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
                 gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t())
                 grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
-
             elif state.CB is not None:
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
-            # elif state.CxB is not None:
-            #     CB = (
-            #         undo_layout(state.CxB, state.tile_indices)
-            #         .to(ctx.dtype_A)
-            #         .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
-            #     )
-            #     grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             else:
                 raise Exception("State must contain either CBt or CB matrix for backward")
 
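For reference, the trimmed-down backward above keeps two routes for grad_A: the int8 route through F.igemmlt + F.mm_dequant when state.CBt is available, and a dequantize-then-matmul fallback when only state.CB is present. A minimal standalone sketch of that fallback (hypothetical helper name; assumes a 2D grad_output, an int8 CB of shape [out_features, in_features], and per-row absmax scales SCB):

import torch

def grad_a_from_cb(grad_output: torch.Tensor, CB: torch.Tensor, SCB: torch.Tensor) -> torch.Tensor:
    # Dequantize the int8 weight with its per-row absmax scales, then matmul,
    # mirroring the state.CB branch of MatMul8bitLt.backward above.
    CB_fp = CB.to(grad_output.dtype) * (SCB.unsqueeze(1) / 127.0)
    return grad_output @ CB_fp  # [tokens, out_features] @ [out_features, in_features]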
bitsandbytes/functional.py

Lines changed: 75 additions & 36 deletions
@@ -2330,7 +2330,7 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32):
     ldb = shapeB[-1]  # Activations (batch, tokens, inputs)
     ldc = shapeC[-1]  # Output (batch, tokens, outputs)
 
-    assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ B={shapeA}"
+    assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
 
     prev_device = A.device
     torch.cuda.set_device(A.device)
@@ -2361,18 +2361,25 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32):
     return out, Sout
 
 
-def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
+def mm_dequant_torch(
+    A: torch.Tensor,
+    quant_state: Optional[Tuple[torch.Size, str]],  # TODO: deprecate. (shape, format)
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    new_row_stats=None,  # TODO: unused
+    new_col_stats=None,  # TODO: unused
+    bias: Optional[torch.Tensor] = None,
+):
     assert A.dtype == torch.int32
 
-    compute_dtype = torch.float32
-
-    A_calc = A.view(-1, A.shape[-1]).to(compute_dtype)
-    row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
-    col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
+    A_calc = A.view(-1, A.shape[-1])
+    row_stats = row_stats.reshape(-1).unsqueeze(-1)
+    col_stats = col_stats.reshape(-1).unsqueeze(0)
 
     # TODO support out != None
 
-    out = A_calc * (row_stats * col_stats) * 6.200124e-5  # .to(torch.float16)
+    out = A_calc * (row_stats * col_stats) * 6.200124e-5
 
     if bias is not None:
         # assert bias.dtype == torch.float16
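A note on the constant in mm_dequant_torch above: row- and column-wise absmax quantization each scale the inputs by 127, so the int32 accumulator carries a combined factor of 127 × 127 = 16129, and multiplying by approximately 1/(127 · 127) ≈ 6.2000e-05 together with row_stats · col_stats recovers the floating-point product. A toy check of that arithmetic, using made-up values:

import torch

# Quantize a 2-element dot product with absmax scaling, accumulate in int32,
# then dequantize the way mm_dequant_torch does.
x, w = torch.tensor([1.5, -0.75]), torch.tensor([0.5, 2.0])
row_stat, col_stat = x.abs().max(), w.abs().max()
xq = torch.round(x * 127 / row_stat).to(torch.int32)
wq = torch.round(w * 127 / col_stat).to(torch.int32)
acc = (xq * wq).sum()                                    # integer accumulator
approx = acc.float() * row_stat * col_stat / (127 * 127)
print(approx.item(), (x * w).sum().item())               # ~ -0.756 vs. -0.75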
@@ -2381,42 +2388,40 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
     return out.to(torch.float16)
 
 
-def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
+def mm_dequant(
+    A: torch.Tensor,
+    quant_state: Optional[Tuple[torch.Size, str]],  # TODO: deprecate. (shape, format)
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    new_row_stats=None,  # TODO: unused
+    new_col_stats=None,  # TODO: unused
+    bias: Optional[torch.Tensor] = None,
+):
     assert A.dtype == torch.int32
+
     if bias is not None:
         assert bias.dtype == torch.float16
-    out_shape = quant_state[0]
-    if len(out_shape) == 3:
-        out_shape = (out_shape[0] * out_shape[1], out_shape[2])
 
     if out is None:
-        out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
-    if new_row_stats is None:
-        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
-    if new_col_stats is None:
-        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
-    assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
-    assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
+        out = torch.empty_like(A, dtype=torch.float16)
 
-    prev_device = pre_call(A.device)
     ptrA = get_ptr(A)
     ptrOut = get_ptr(out)
     ptrRowStats = get_ptr(row_stats)
     ptrColStats = get_ptr(col_stats)
-    ptrNewRowStats = get_ptr(new_row_stats)
-    ptrNewColStats = get_ptr(new_col_stats)
     ptrBias = get_ptr(bias)
-    numRows = ct.c_int32(out_shape[0])
-    numCols = ct.c_int32(out_shape[1])
+    numRows = ct.c_int32(prod(A.shape[:-1]))
+    numCols = ct.c_int32(A.shape[-1])
+
+    is_on_gpu([A, row_stats, col_stats, out, bias])
 
-    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
+    prev_device = pre_call(A.device)
     lib.cdequant_mm_int32_fp16(
         ptrA,
         ptrRowStats,
         ptrColStats,
         ptrOut,
-        ptrNewRowStats,
-        ptrNewColStats,
         ptrBias,
         numRows,
         numCols,
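Because the reworked mm_dequant above derives numRows and numCols from A itself (prod(A.shape[:-1]) and A.shape[-1]) rather than from quant_state, the row-major kernel accepts plain 2D or batched int32 inputs. A sketch of how one might sanity-check the kernel path against the pure-PyTorch path, assuming a CUDA device and made-up stats:

import torch
import bitsandbytes.functional as F

A_i32 = torch.randint(-1000, 1000, (8, 16), dtype=torch.int32, device="cuda")
row_stats = torch.rand(8, device="cuda") + 0.1     # assumed absmax-style stats
col_stats = torch.rand(16, device="cuda") + 0.1
out_kernel = F.mm_dequant(A_i32, None, row_stats, col_stats)       # kernel path
out_torch = F.mm_dequant_torch(A_i32, None, row_stats, col_stats)  # reference path
print((out_kernel.float() - out_torch.float()).abs().max())        # expected to be small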
@@ -2426,7 +2431,33 @@ def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats
     return out
 
 
-def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
+def get_colrow_absmax(
+    A: torch.Tensor,
+    row_stats: torch.Tensor = None,
+    col_stats: torch.Tensor = None,
+    nnz_block_ptr: torch.Tensor = None,
+    threshold=0.0,
+):
+    # Note: prior impl only works with fp16
+    assert A.is_floating_point()
+
+    if row_stats is None or col_stats is None:
+        absA = A.abs().view(-1, A.shape[-1])  # view as 2D
+        if row_stats is None:
+            # shape [rows]; unsqueeze(-1) gives [rows,1]
+            row_stats = absA.amax(dim=1, keepdim=False).float()
+        if col_stats is None:
+            # shape [cols]; unsqueeze(0) gives [1,cols]
+            col_stats = absA.amax(dim=0, keepdim=False).float()
+
+    # TODO: threshold support
+    if nnz_block_ptr is None and threshold > 0.0:
+        nnz_block_ptr = torch.zeros_like(A, dtype=torch.int32)
+
+    return row_stats, col_stats, nnz_block_ptr
+
+
+def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
     assert A.dtype == torch.float16
     device = A.device
 
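The rewritten get_colrow_absmax above takes the absolute maximum over the last dimension for row stats and over all leading dimensions for column stats. A small shape check with made-up values:

import torch

A = torch.tensor([[1.0, -4.0, 2.0],
                  [-3.0, 0.5, -6.0]])
absA = A.abs().view(-1, A.shape[-1])   # 2D view, shape [rows, cols]
row_stats = absA.amax(dim=1)           # tensor([4., 6.]): one absmax per row
col_stats = absA.amax(dim=0)           # tensor([3., 4., 6.]): one absmax per column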
@@ -2543,19 +2574,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
     return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
 
 
+@torch.compile
 def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
-    # TODO: Optimize/write CUDA kernel for this. Currently vectorwise_quant will recalculate row/col stats.
+    # TODO: Optimize/write CUDA kernel for this
     # TODO: Support threshold
 
-    # if out_col is None:
-    #     out_col = torch.zeros(A.shape, device=A.device, dtype=torch.int8)
-    # if out_row is None:
-    #     out_row = torch.zeros(A.shape, device=A.device, dtype=torch.int8)
+    if row_stats is None or col_stats is None:
+        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+
+    scaled_A = A.mul(C)
+
+    # quant_row = torch.round(A * (C / row_stats.unsqueeze(-1))).to(torch.int8)
+    # quant_col = torch.round(A * (C / col_stats.unsqueeze(0))).to(torch.int8)
+    quant_row = torch.round(scaled_A / row_stats.unsqueeze(-1)).to(torch.int8)
+    quant_col = torch.round(scaled_A / col_stats.unsqueeze(0)).to(torch.int8)
 
-    out_col, Scol = vectorwise_quant(A, dim=0)
-    out_row, Srow = vectorwise_quant(A, dim=1)
+    if out_row is not None:
+        quant_row = out_row.copy_(quant_row)
+    if out_col is not None:
+        quant_col = out_col.copy_(quant_col)
 
-    return out_row, out_col, Srow.flatten().float(), Scol.flatten().float(), None  # coo_tensor
+    return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), None
 
 
 def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
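The compiled double_quant above replaces the two vectorwise_quant calls with one absmax pass (get_colrow_absmax) followed by symmetric rounding into int8, scaled by the module-level constant C; the sketch below simply uses 127.0 for that scale. A rough round-trip of what the row-wise output represents, with made-up values:

import torch

A = torch.tensor([[0.2, -1.0, 0.5]])
row_stats = A.abs().amax(dim=1).float()                        # tensor([1.])
quant_row = torch.round(A * 127.0 / row_stats.unsqueeze(-1)).to(torch.int8)
A_hat = quant_row.float() * row_stats.unsqueeze(-1) / 127.0    # ~[[0.197, -1.0, 0.504]]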
