@@ -39,8 +39,7 @@

 log = logging.getLogger(__name__)

-# TODO: unpin after resolving the `quant_state` format breaking changes
-_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes==0.41.0")
+_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.42.0")


 class BitsandbytesPrecision(Precision):
@@ -344,7 +343,7 @@ def quantize(
     def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
         if self.weight.dtype == torch.uint8:  # was quantized
             # cannot init the quantized params directly
-            weight = torch.empty(self.weight.quant_state[1], device=device, dtype=torch.half)
+            weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
         else:
             weight = torch.empty_like(self.weight.data, device=device)
         device = torch.device(device)
@@ -366,7 +365,7 @@ def reset_parameters(self) -> None:
         linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
         if linear_init_finished and self.weight.dtype == torch.uint8:  # was quantized
             # cannot init the quantized params directly
-            weight = torch.empty(self.weight.quant_state[1], device=self.weight.device, dtype=torch.half)
+            weight = torch.empty(self.weight.quant_state.shape, device=self.weight.device, dtype=torch.half)
         else:
             weight = self.weight.data
         torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
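For context on the `quant_state[1]` -> `quant_state.shape` change: a minimal sketch (not part of this commit, and the helper name `_quant_state_shape` is hypothetical) of version-tolerant access to the original tensor shape of a quantized `Params4bit` weight, assuming only the two access patterns visible in the diff: the 0.41.x sequence-style state whose second element held the shape, and the >=0.42.0 state object that exposes a `.shape` attribute.

# Hedged sketch, not from the commit: read the original shape from a
# quantized Params4bit's quant_state across bitsandbytes versions.
import torch


def _quant_state_shape(weight: torch.nn.Parameter) -> torch.Size:
    qs = weight.quant_state
    if hasattr(qs, "shape"):
        # bitsandbytes >= 0.42.0: quant_state is an object with a `.shape` attribute
        return qs.shape
    # bitsandbytes 0.41.x: quant_state was a sequence; its second element was
    # the shape, hence the old `quant_state[1]` access removed in this diff
    return qs[1]


# Usage, mirroring the diffed code paths:
# weight = torch.empty(_quant_state_shape(self.weight), device=device, dtype=torch.half)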