Commit 0cc5c95

Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation

1 parent 0500c31

10 files changed, +382 −243 lines changed
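The diff below (only bitsandbytes/autograd/_functions.py is shown here) drops the col32/col_turing/col_ampere tile layouts and calls igemmlt directly on row-major int8 tensors. As orientation, here is a minimal sketch of the forward int8 path after this commit; the tensor shapes and the CUDA device are illustrative assumptions, and only the plain path without outlier decomposition is shown:

import torch
import bitsandbytes.functional as F

# Illustrative shapes: A = activations, B = weight in (out_features, in_features) layout.
A = torch.randn(4, 128, dtype=torch.float16, device="cuda")
B = torch.randn(256, 128, dtype=torch.float16, device="cuda")

# Row-wise int8 quantization of both operands (returns row/col quantized tensors and scales).
CA, CAt, SCA, SCAt, _ = F.double_quant(A)
CB, CBt, SCB, SCBt, _ = F.double_quant(B)

# Before this commit the quantized tensors were first re-laid-out, e.g.
#   C32A, SA = F.transform(CA, "col32"); CxB, SB = F.transform(CB, to_order=formatB)
#   out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# After this commit, igemmlt consumes the row-major int8 tensors directly:
out32, Sout32 = F.igemmlt(CA, CB)

# Dequantize the int32 accumulator back to the input dtype.
out = F.mm_dequant(out32, Sout32, SCA, SCB).to(A.dtype)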

bitsandbytes/autograd/_functions.py

Lines changed: 39 additions & 44 deletions
@@ -245,11 +245,11 @@ class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
-    CxB = None
+    CxB = None  # TODO: Deprecate/remove
     SB = None
     SCB = None
 
-    CxBt = None
+    CxBt = None  # TODO: Deprecate/remove
     SBt = None
     CBt = None
 
@@ -263,7 +263,7 @@ class MatmulLtState:
     has_fp16_weights = True
     memory_efficient_backward = False
     use_pool = False
-    formatB = F.get_special_format_str()
+    formatB = "row"  # F.get_special_format_str() TODO: Deprecate/remove
 
     def reset_grads(self):
         self.CB = None
@@ -283,9 +283,6 @@ def tile_indices(self):
 
 
 class MatMul8bitLt(torch.autograd.Function):
-    # forward is the same, but we added the fallback for pre-turing GPUs
-    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
-
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
         using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
@@ -306,7 +303,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
@@ -328,14 +324,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            else:
-                if state.CxB is None and using_igemmlt:
-                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
-                    # we also need to convert it to the turing/ampere format
-                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
-            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
-                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
             subA = None
 
         # 2. Quantize B
@@ -345,19 +334,17 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
             if is_transposed:
                 B = B.contiguous()
 
-            if (state.is_training and not has_grad) or state.CxB is None:
+            if (state.is_training and not has_grad) or state.CB is None:
                 state.reset_grads()
+
+                # quantize...
                 (
-                    CB,
+                    state.CB,
                     state.CBt,
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
                 ) = F.double_quant(B.to(torch.float16))
-                if using_igemmlt:
-                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
-                else:
-                    state.CB = CB
         else:
             has_grad = False
 
@@ -372,17 +359,18 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
             #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
             # else:
             #     state.idx = outlier_idx
-            if state.CxB is not None:
-                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-            else:
-                outliers = state.CB[:, state.idx.long()].clone()
+
+            # if state.CxB is not None:
+            #     outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
+            # else:
+            outliers = state.CB[:, state.idx.long()].clone()
 
             state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0
             subA = A[:, state.idx.long()]
 
-        shapeB = state.SB[0] if state.SB else B.shape
+        shapeB = state.CB.shape
 
         if len(input_shape) == 3:
             output_shape = (input_shape[0], input_shape[1], shapeB[0])
@@ -391,13 +379,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
 
         # 3. Matmul
         if using_igemmlt:
-            C32A, SA = F.transform(CA, "col32")
-            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
+            out32, Sout32 = F.igemmlt(CA, state.CB)
+
             if bias is None or bias.dtype == torch.float16:
                 # we apply the fused bias here
                 output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
                 output = output.to(A.dtype)
             else:  # apply bias separately
+                # TODO: Fused bias for fp32/bf16?
                 output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
                 output = output.to(A.dtype).add_(bias)
 
@@ -417,7 +406,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
         # 5. Save state
         ctx.state = state
 
-        ctx.formatB = formatB
         ctx.grad_shape = input_shape
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
@@ -437,10 +425,10 @@ def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
         req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA, A = ctx.tensors
         SCAt, idx = ctx.tensor_states
-        formatB = ctx.formatB
         state = ctx.state
         grad_A = grad_B = grad_bias = None
 
@@ -454,33 +442,39 @@ def backward(ctx, grad_output):
 
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
-            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
-            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
-            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
+            # CxAt, SAt = F.transform(CAt, formatB, transpose=True)
+            # C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
+            # gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
+            # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+            gradB32, SgradB32 = F.igemmlt(
+                Cgradt.t(), CAt.t()
+            )  # issue here in test_linear_serialization w/ has fp16 weights
             grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
             if state.threshold > 0.0 and subA is not None:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
             if state.CBt is not None:
-                C32grad, Sgrad = F.transform(Cgrad, "col32")
-                if state.CxBt is None:
-                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
-                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                # C32grad, Sgrad = F.transform(Cgrad, "col32")
+                # if state.CxBt is None:
+                #     state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+                # gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
+                gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t())
                 grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
 
             elif state.CB is not None:
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
-            elif state.CxB is not None:
-                CB = (
-                    undo_layout(state.CxB, state.tile_indices)
-                    .to(ctx.dtype_A)
-                    .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
-                )
-                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            # elif state.CxB is not None:
+            #     CB = (
+            #         undo_layout(state.CxB, state.tile_indices)
+            #         .to(ctx.dtype_A)
+            #         .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+            #     )
+            #     grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             else:
-                raise Exception("State must contain either CBt or CB or CxB matrix for backward")
+                raise Exception("State must contain either CBt or CB matrix for backward")
 
         return grad_A, grad_B, None, grad_bias, None
 
@@ -564,6 +558,7 @@ def matmul_4bit(
     bias=None,
 ):
     assert quant_state is not None
+
    if A.numel() == A.shape[-1] and A.requires_grad == False:
        if A.shape[-1] % quant_state.blocksize != 0:
            warn(