File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -370,8 +370,6 @@ def backward(ctx, grad_output):
370370 if state .threshold > 0.0 and subA is not None :
371371 grad_B [:, idx ] += torch .matmul (grad_output .t (), subA )
372372
373- raise NotImplementedError ("!!" )
374-
375373 if req_gradA :
376374 if state .CBt is not None :
377375 C32grad , Sgrad = F .transform (Cgrad , "col32" )
Original file line number Diff line number Diff line change @@ -237,7 +237,9 @@ def __init__(
237237 if threshold > 0.0 and not has_fp16_weights :
238238 self .state .use_pool = True
239239
240- self .weight = Int8Params (self .weight .data , has_fp16_weights = has_fp16_weights )
240+ self .weight = Int8Params (
241+ self .weight .data , has_fp16_weights = has_fp16_weights , requires_grad = has_fp16_weights
242+ )
241243
242244 def init_8bit_state (self ):
243245 self .state .CB = self .weight .CB
You can’t perform that action at this time.
0 commit comments