 import operator
+import warnings
+
 import torch
 import bitsandbytes.functional as F
 
@@ -184,6 +186,7 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()
 
@@ -209,31 +212,29 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         ctx.B = B
         ctx.bias = bias
         if A.shape[-1] == B.shape[0]:
-            return torch.empty(A.shape[:-1] + B.shape[1:], dtype=torch.float16, device=A.device)
+            return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
         else:
-            return torch.empty(A.shape[:-1] + B.shape[:1], dtype=torch.float16, device=A.device)
+            return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
 
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        requires_gradA = A.requires_grad
-        requires_gradB = B.requires_grad
-        requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
 
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )
 
         if state.threshold > 0.0 and coo_tensorA is not None:
@@ -269,7 +270,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
-                ) = F.double_quant(B)
+                ) = F.double_quant(B.to(torch.float16))
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
         else:
             has_grad = False
@@ -290,7 +291,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                 (outliers * state.SCB.view(-1, 1) / 127.0)
                 .t()
                 .contiguous()
-                .half()
+                .to(A.dtype)
             )
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0
@@ -307,7 +308,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A.dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A.dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -318,42 +325,43 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
 
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
-        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
-        if requires_gradA or requires_gradB:
+        if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (CAt, subA)
             ctx.tensor_states = (SCAt, state.idx)
         else:
             ctx.tensors = [None, None]
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
+
         clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
-        #clone_func = torch.clone
         return clone_func(output.view(output_shape))
 
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, req_gradBias = ctx.req_grads
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."
+        grad_A = grad_B = grad_bias = None
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
 
+        # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()
 
-        grad_A = grad_B = grad_bias = None
-
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -363,16 +371,20 @@ def backward(ctx, grad_output):
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
-            C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(
-                    state.CBt, to_order=formatB, transpose=True
-                )
-            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+            if state.CBt is not None:
+                C32grad, Sgrad = F.transform(Cgrad, "col32")
+                if state.CxBt is None:
+                    state.CxBt, state.SBt = F.transform(
+                        state.CBt, to_order=formatB, transpose=True
+                    )
+                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
 
-        if req_gradBias:
-            grad_bias = grad_output.sum(0)
+            elif state.CB is not None:
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            else:
+                raise Exception('State must contain either CBt or CB matrix for backward')
 
         return grad_A, grad_B, None, grad_bias, None
 
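Note on the new backward path: when the transposed quantized weight state.CBt is unavailable (the memory-efficient case), grad_A is obtained by dequantizing the row-major int8 weight state.CB with its per-row scales state.SCB and falling back to a plain matmul. Below is a minimal standalone sketch of that computation, not part of the diff itself; grad_out, CB_int8, and SCB are hypothetical stand-ins for grad_output, state.CB, and state.SCB.

# Minimal sketch of the dequantize-and-matmul fallback used for grad_A.
import torch

def grad_A_from_CB(grad_out: torch.Tensor,
                   CB_int8: torch.Tensor,
                   SCB: torch.Tensor,
                   dtype_A: torch.dtype) -> torch.Tensor:
    # CB_int8: (out_features, in_features) int8 weight, quantized so that
    # 127 maps to each row's absmax; SCB: (out_features,) row-wise scales.
    CB = CB_int8.to(dtype_A).mul_(SCB.unsqueeze(1).mul(1.0 / 127.0))
    # grad_A = dL/dY @ W, with W the dequantized weight of shape (out, in)
    return torch.matmul(grad_out.to(dtype_A), CB)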