Skip to content

Commit 0aefeb0

Browse files
int8 more cleanup
1 parent 875414e commit 0aefeb0

File tree

1 file changed

+5
-14
lines changed

1 file changed

+5
-14
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ class MatmulLtState:
264264
has_fp16_weights = True
265265
memory_efficient_backward = False
266266
use_pool = False
267-
formatB = "row" # F.get_special_format_str() TODO: Deprecate/remove
267+
formatB = "row" # TODO: Deprecate/remove
268268

269269
def reset_grads(self):
270270
self.CB = None
@@ -394,9 +394,9 @@ def forward(
394394
output_shape = (*input_shape[:-1], state.CB.shape[0])
395395

396396
if len(input_shape) == 3:
397-
return output.reshape(output_shape) # .clone()
398-
else:
399-
return output
397+
return output.reshape(output_shape)
398+
399+
return output
400400

401401
@staticmethod
402402
def backward(ctx, grad_output):
@@ -418,11 +418,6 @@ def backward(ctx, grad_output):
418418
if len(grad_output.shape) == 3:
419419
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
420420

421-
# if req_gradB:
422-
# grad_B = torch.matmul(grad_output.t(), A)
423-
# if state.threshold > 0.0 and subA is not None:
424-
# grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
425-
# Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
426421
if req_gradB:
427422
Cgrad, _, _, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
428423

@@ -432,15 +427,11 @@ def backward(ctx, grad_output):
432427
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
433428

434429
if req_gradA:
435-
# grad_output @ B.T
436-
# if state.CBt is not None:
437-
# gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
438-
# grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
439430
if state.CB is not None:
440431
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
441432
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
442433
else:
443-
raise Exception("State must contain either CBt or CB matrix for backward")
434+
raise Exception("State must contain CB matrix for backward")
444435

445436
return grad_A, grad_B, None, grad_bias, None
446437

0 commit comments

Comments (0)