@@ -245,11 +245,11 @@ class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
-    CxB = None
+    CxB = None  # TODO: Deprecate/remove
     SB = None
     SCB = None
 
-    CxBt = None
+    CxBt = None  # TODO: Deprecate/remove
     SBt = None
     CBt = None
 
@@ -263,7 +263,7 @@ class MatmulLtState:
     has_fp16_weights = True
     memory_efficient_backward = False
     use_pool = False
-    formatB = F.get_special_format_str()
+    formatB = "row"  # F.get_special_format_str() TODO: Deprecate/remove
 
     def reset_grads(self):
         self.CB = None
@@ -283,9 +283,6 @@ def tile_indices(self):
 
 
 class MatMul8bitLt(torch.autograd.Function):
-    # forward is the same, but we added the fallback for pre-turing GPUs
-    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
-
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
         using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
@@ -306,7 +303,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
@@ -328,14 +324,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            else:
-                if state.CxB is None and using_igemmlt:
-                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
-                    # we also need to convert it to the turing/ampere format
-                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
-            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
-                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
             subA = None
 
         # 2. Quantize B
@@ -345,19 +334,17 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
             if is_transposed:
                 B = B.contiguous()
 
-            if (state.is_training and not has_grad) or state.CxB is None:
+            if (state.is_training and not has_grad) or state.CB is None:
                 state.reset_grads()
+
+                # quantize...
                 (
-                    CB,
+                    state.CB,
                     state.CBt,
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
                 ) = F.double_quant(B.to(torch.float16))
-                if using_igemmlt:
-                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
-                else:
-                    state.CB = CB
         else:
             has_grad = False
 
@@ -372,17 +359,18 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
             #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
             # else:
             #     state.idx = outlier_idx
-            if state.CxB is not None:
-                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-            else:
-                outliers = state.CB[:, state.idx.long()].clone()
+
+            # if state.CxB is not None:
+            #     outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
+            # else:
+            outliers = state.CB[:, state.idx.long()].clone()
 
             state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0
             subA = A[:, state.idx.long()]
 
-        shapeB = state.SB[0] if state.SB else B.shape
+        shapeB = state.CB.shape
 
         if len(input_shape) == 3:
             output_shape = (input_shape[0], input_shape[1], shapeB[0])
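A rough standalone illustration of the row-major outlier handling this hunk keeps: outlier columns are sliced straight out of the int8 `CB` and dequantized into `subB`, with no `extract_outliers` call on a Turing/Ampere layout. The tensors below are placeholder data, assumed shapes only:

```python
import torch

# Assume CB (int8, shape [out, in]) and SCB (per-row absmax scales) came from row-wise
# quantization, and idx holds the outlier feature indices found in the activations.
CB = torch.randint(-127, 128, (8, 16), dtype=torch.int8)
SCB = torch.rand(8) + 0.5
idx = torch.tensor([2, 5])

outliers = CB[:, idx].clone()                                          # int8 columns for outlier features
subB = (outliers.float() * SCB.view(-1, 1) / 127.0).t().contiguous()   # dequantized, shape [len(idx), out]
print(subB.shape)
```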
@@ -391,13 +379,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
 
         # 3. Matmul
         if using_igemmlt:
-            C32A, SA = F.transform(CA, "col32")
-            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
+            out32, Sout32 = F.igemmlt(CA, state.CB)
+
             if bias is None or bias.dtype == torch.float16:
                 # we apply the fused bias here
                 output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
                 output = output.to(A.dtype)
             else:  # apply bias separately
+                # TODO: Fused bias for fp32/bf16?
                 output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
                 output = output.to(A.dtype).add_(bias)
 
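As a reference for what the simplified `igemmlt` + `mm_dequant` pair computes on row-major int8 inputs, here is a plain-PyTorch sketch; the helper is hypothetical, emulates the int8/int32 GPU kernels in fp32, and simplifies the fused-bias detail:

```python
import torch

def int8_matmul_dequant_sketch(CA, SCA, CB, SCB, bias=None):
    """Emulate int8 GEMM with int32 accumulation, then dequantize with the row scales."""
    out32 = (CA.float() @ CB.float().t()).to(torch.int32)            # [batch, out] accumulators
    output = out32.float() * (SCA.unsqueeze(1) * SCB.unsqueeze(0)) / (127.0 * 127.0)
    if bias is not None:
        output = output + bias                                       # fused into mm_dequant when bias is fp16
    return output

CA = torch.randint(-127, 128, (4, 16), dtype=torch.int8)
CB = torch.randint(-127, 128, (8, 16), dtype=torch.int8)
SCA, SCB = torch.rand(4) + 0.5, torch.rand(8) + 0.5
print(int8_matmul_dequant_sketch(CA, SCA, CB, SCB).shape)            # torch.Size([4, 8])
```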
@@ -417,7 +406,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
         # 5. Save state
         ctx.state = state
 
-        ctx.formatB = formatB
         ctx.grad_shape = input_shape
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
@@ -437,10 +425,10 @@ def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
         req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA, A = ctx.tensors
         SCAt, idx = ctx.tensor_states
-        formatB = ctx.formatB
         state = ctx.state
         grad_A = grad_B = grad_bias = None
 
@@ -454,33 +442,39 @@ def backward(ctx, grad_output):
 
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
-            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
-            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
-            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
+            # CxAt, SAt = F.transform(CAt, formatB, transpose=True)
+            # C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
+            # gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
+            # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+            gradB32, SgradB32 = F.igemmlt(
+                Cgradt.t(), CAt.t()
+            )  # issue here in test_linear_serialization w/ has fp16 weights
             grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
             if state.threshold > 0.0 and subA is not None:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
             if state.CBt is not None:
-                C32grad, Sgrad = F.transform(Cgrad, "col32")
-                if state.CxBt is None:
-                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
-                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                # C32grad, Sgrad = F.transform(Cgrad, "col32")
+                # if state.CxBt is None:
+                #     state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+                # gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
+                gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t())
                 grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
 
             elif state.CB is not None:
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
-            elif state.CxB is not None:
-                CB = (
-                    undo_layout(state.CxB, state.tile_indices)
-                    .to(ctx.dtype_A)
-                    .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
-                )
-                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            # elif state.CxB is not None:
+            #     CB = (
+            #         undo_layout(state.CxB, state.tile_indices)
+            #         .to(ctx.dtype_A)
+            #         .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+            #     )
+            #     grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             else:
-                raise Exception("State must contain either CBt or CB or CxB matrix for backward")
+                raise Exception("State must contain either CBt or CB matrix for backward")
 
         return grad_A, grad_B, None, grad_bias, None
 
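For orientation, the backward math this hunk preserves, written as a plain-PyTorch reference; the helper is hypothetical and ignores the int8 quantization of `grad_output` and the outlier `subA` correction:

```python
import torch

def lt_backward_reference(grad_output, A, CB, SCB):
    """Reference gradients for out = A @ W.t(), with W dequantized from int8 CB and row scales SCB."""
    W = CB.float() * (SCB.unsqueeze(1) / 127.0)   # dequantized weight, shape [out, in]
    grad_A = grad_output @ W                      # what the CBt / CB branches compute for req_gradA
    grad_B = grad_output.t() @ A                  # what the igemmlt(Cgradt.t(), CAt.t()) path approximates
    return grad_A, grad_B

grad_out = torch.randn(4, 8)
A = torch.randn(4, 16)
CB = torch.randint(-127, 128, (8, 16), dtype=torch.int8)
SCB = torch.rand(8) + 0.5
gA, gB = lt_backward_reference(grad_out, A, CB, SCB)
print(gA.shape, gB.shape)   # torch.Size([4, 16]) torch.Size([8, 16])
```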
@@ -564,6 +558,7 @@ def matmul_4bit(
     bias=None,
 ):
     assert quant_state is not None
+
     if A.numel() == A.shape[-1] and A.requires_grad == False:
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(