@@ -264,7 +264,7 @@ class MatmulLtState:
264264 has_fp16_weights = True
265265 memory_efficient_backward = False
266266 use_pool = False
267- formatB = "row" # F.get_special_format_str() TODO: Deprecate/remove
267+ formatB = "row" # TODO: Deprecate/remove
268268
269269 def reset_grads (self ):
270270 self .CB = None
@@ -394,9 +394,9 @@ def forward(
394394 output_shape = (* input_shape [:- 1 ], state .CB .shape [0 ])
395395
396396 if len (input_shape ) == 3 :
397- return output .reshape (output_shape ) # .clone()
398- else :
399- return output
397+ return output .reshape (output_shape )
398+
399+ return output
400400
401401 @staticmethod
402402 def backward (ctx , grad_output ):
@@ -418,11 +418,6 @@ def backward(ctx, grad_output):
418418 if len (grad_output .shape ) == 3 :
419419 grad_output = grad_output .reshape (- 1 , grad_output .shape [- 1 ]).contiguous ()
420420
421- # if req_gradB:
422- # grad_B = torch.matmul(grad_output.t(), A)
423- # if state.threshold > 0.0 and subA is not None:
424- # grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
425- # Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
426421 if req_gradB :
427422 Cgrad , _ , _ , SCgradt , _ = F .double_quant (grad_output .to (torch .float16 ))
428423
@@ -432,15 +427,11 @@ def backward(ctx, grad_output):
432427 grad_B [:, idx ] += torch .matmul (grad_output .t (), subA )
433428
434429 if req_gradA :
435- # grad_output @ B.T
436- # if state.CBt is not None:
437- # gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
438- # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
439430 if state .CB is not None :
440431 CB = state .CB .to (ctx .dtype_A , copy = True ).mul_ (state .SCB .unsqueeze (1 ).mul (1.0 / 127.0 ))
441432 grad_A = torch .matmul (grad_output , CB ).view (ctx .grad_shape ).to (ctx .dtype_A )
442433 else :
443- raise Exception ("State must contain either CBt or CB matrix for backward" )
434+ raise Exception ("State must contain CB matrix for backward" )
444435
445436 return grad_A , grad_B , None , grad_bias , None
446437
0 commit comments