@@ -226,7 +226,7 @@ class _Linear8bitLt(bnb.nn.Linear8bitLt):
     def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None:
         super().__init__(*args, device=device, threshold=threshold, **kwargs)
         self.weight = cast(bnb.nn.Int8Params, self.weight)  # type: ignore[has-type]
-        self.bias = cast(Optional[torch.nn.Parameter], self.bias)  # type: ignore[has-type]
+        self.bias: Optional[torch.nn.Parameter] = self.bias
         # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
         # filling the device memory with float32 weights which could lead to OOM
         if torch.tensor(0, device=device).device.type == "cuda":
@@ -310,7 +310,7 @@ class _Linear4bit(bnb.nn.Linear4bit):
     def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None:
         super().__init__(*args, device=device, **kwargs)
         self.weight = cast(bnb.nn.Params4bit, self.weight)  # type: ignore[has-type]
-        self.bias = cast(Optional[torch.nn.Parameter], self.bias)  # type: ignore[has-type]
+        self.bias: Optional[torch.nn.Parameter] = self.bias
         # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
         # filling the device memory with float32 weights which could lead to OOM
         if torch.tensor(0, device=device).device.type == "cuda":
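Note on the eager-quantization check in both constructors (a minimal sketch, not part of the diff): creating a throwaway tensor with torch.tensor(0, device=device) resolves the device that would actually be used, including a device set by an enclosing torch.device("cuda") context manager, which is why the weight can be quantized immediately instead of first materializing float32 weights on the GPU. The helper name effective_device_type below is hypothetical, for illustration only, and assumes a CUDA-capable PyTorch build for the second branch.

    import torch

    def effective_device_type(device=None) -> str:
        # A tiny tensor resolves the device that new tensors would land on,
        # including any device set by an enclosing context manager.
        return torch.tensor(0, device=device).device.type

    print(effective_device_type())          # "cpu" outside any context manager
    if torch.cuda.is_available():
        with torch.device("cuda"):
            # device=None still resolves to "cuda" inside the context manager,
            # which is the case the constructors above quantize eagerly for.
            print(effective_device_type())  # "cuda"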