@@ -226,7 +226,7 @@ class _Linear8bitLt(bnb.nn.Linear8bitLt):
     def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None:
         super().__init__(*args, device=device, threshold=threshold, **kwargs)
         self.weight = cast(bnb.nn.Int8Params, self.weight)  # type: ignore[has-type]
-        self.bias: Optional[torch.nn.Parameter] = self.bias
+        self.bias: Optional[torch.nn.Parameter] = self.bias
         # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
         # filling the device memory with float32 weights which could lead to OOM
         if torch.tensor(0, device=device).device.type == "cuda":
@@ -310,7 +310,7 @@ class _Linear4bit(bnb.nn.Linear4bit):
     def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None:
         super().__init__(*args, device=device, **kwargs)
         self.weight = cast(bnb.nn.Params4bit, self.weight)  # type: ignore[has-type]
-        self.bias: Optional[torch.nn.Parameter] = self.bias
+        self.bias: Optional[torch.nn.Parameter] = self.bias
         # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
         # filling the device memory with float32 weights which could lead to OOM
         if torch.tensor(0, device=device).device.type == "cuda":
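Both constructors rely on the same idiom to decide whether to quantize eagerly: creating a zero-dim tensor with the given `device` (possibly `None`) and inspecting its resolved device type. This covers both an explicit CUDA device argument and instantiation inside a `torch.device("cuda")` context manager, so the float32 weights are quantized before they can accumulate in device memory. The sketch below is a minimal illustration of that idiom only, assuming PyTorch >= 2.0 (where `torch.device` works as a context manager); the helper name `_is_cuda_context` is made up for illustration and is not part of the plugin.

    import torch

    # Sketch of the device-resolution idiom used in the diff above.
    # A zero-dim tensor created with device=None picks up the default device,
    # including one set by an enclosing `torch.device(...)` context manager,
    # so the check works for both an explicit device and a CUDA context.
    def _is_cuda_context(device=None) -> bool:  # hypothetical helper
        return torch.tensor(0, device=device).device.type == "cuda"

    print(_is_cuda_context("cpu"))        # False
    if torch.cuda.is_available():
        print(_is_cuda_context("cuda"))   # True: explicit CUDA device
        with torch.device("cuda"):
            print(_is_cuda_context())     # True: inherited from the context manager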