Skip to content

Commit 527d071

Browse files
kashif, Borda, awaelchli, and carmocca
authored
Bump bitsandbytes minimum version (#19520)
Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: awaelchli <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent b19c3a9 commit 527d071

File tree

3 files changed

+5
-6
lines changed

3 files changed

+5
-6
lines changed

requirements/fabric/strategies.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
# note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
77
# shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
88
deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
9-
bitsandbytes ==0.41.0 # strict
9+
bitsandbytes >=0.42.0,<0.43.0

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ hydra-core >=1.0.5, <1.4.0
88
jsonargparse[signatures] >=4.27.5, <4.28.0
99
rich >=12.3.0, <13.6.0
1010
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
11-
bitsandbytes ==0.41.0 # strict
11+
bitsandbytes >=0.42.0,<0.43.0

src/lightning/fabric/plugins/precision/bitsandbytes.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@
3939

4040
log = logging.getLogger(__name__)
4141

42-
# TODO: unpin after resolving the `quant_state` format breaking changes
43-
_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes==0.41.0")
42+
_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.42.0")
4443

4544

4645
class BitsandbytesPrecision(Precision):
@@ -344,7 +343,7 @@ def quantize(
344343
def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
345344
if self.weight.dtype == torch.uint8: # was quantized
346345
# cannot init the quantized params directly
347-
weight = torch.empty(self.weight.quant_state[1], device=device, dtype=torch.half)
346+
weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
348347
else:
349348
weight = torch.empty_like(self.weight.data, device=device)
350349
device = torch.device(device)
@@ -366,7 +365,7 @@ def reset_parameters(self) -> None:
366365
linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
367366
if linear_init_finished and self.weight.dtype == torch.uint8: # was quantized
368367
# cannot init the quantized params directly
369-
weight = torch.empty(self.weight.quant_state[1], device=self.weight.device, dtype=torch.half)
368+
weight = torch.empty(self.weight.quant_state.shape, device=self.weight.device, dtype=torch.half)
370369
else:
371370
weight = self.weight.data
372371
torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

0 commit comments

Comments
 (0)