@@ -366,9 +366,8 @@ def backward(ctx, grad_output):
366366 CxAt , SAt = F .transform (CAt , formatB , transpose = True )
367367 C32grad , Sgrad = F .transform (Cgradt , "col32" , transpose = True )
368368 gradB32 , SgradB32 = F .igemmlt (C32grad , CxAt , Sgrad , SAt )
369- grad_B = F .mm_dequant (gradB32 , SgradB32 , SCgradt , SCAt ). to ( ctx . dtype_B )
369+ grad_B = F .mm_dequant (gradB32 , SgradB32 , SCgradt , SCAt )
370370 if state .threshold > 0.0 and subA is not None :
371- assert False , idx
372371 grad_B [:, idx ] += torch .matmul (grad_output .t (), subA )
373372
374373 if req_gradA :
@@ -382,8 +381,7 @@ def backward(ctx, grad_output):
382381 grad_A = F .mm_dequant (gradA32 , SgradA32 , SCgrad , state .SCBt ).view (ctx .grad_shape ).to (ctx .dtype_A )
383382
384383 elif state .CB is not None :
385- CB = state .CB .to (ctx .dtype_B )
386- CB .mul_ (state .SCB .unsqueeze (1 ).div_ (127.0 ).to (CB .dtype ))
384+ CB = state .CB .to (ctx .dtype_A , copy = True ).mul_ (state .SCB .unsqueeze (1 ).div (127.0 ))
387385 grad_A = torch .matmul (grad_output , CB ).view (ctx .grad_shape ).to (ctx .dtype_A )
388386 else :
389387 raise Exception ('State must contain either CBt or CB matrix for backward' )
0 commit comments