Commit 50ee994

Merge branch 'main' into absmax
2 parents 8799041 + d9333aa

4 files changed, +56 −38 lines changed

README.md

Lines changed: 49 additions & 13 deletions
@@ -25,13 +25,25 @@ bitsandbytes has the following minimum requirements for all platforms:
 
 #### Accelerator support:
 
+<small>Note: this table reflects the status of the current development branch. For the latest stable release, see the
+[document in the v0.46.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.46.0/README.md#accelerator-support).
+</small>
+
+##### Legend:
+🚧 = In Development,
+〰️ = Partially Supported,
+✅ = Supported,
+❌ = Not Supported
+
 <table>
 <thead>
 <tr>
 <th>Platform</th>
 <th>Accelerator</th>
 <th>Hardware Requirements</th>
-<th>Support Status</th>
+<th>LLM.int8()</th>
+<th>QLoRA 4-bit</th>
+<th>8-bit Optimizers</th>
 </tr>
 </thead>
 <tbody>
@@ -42,13 +54,17 @@ bitsandbytes has the following minimum requirements for all platforms:
 <td align="right">x86-64</td>
 <td>◻️ CPU</td>
 <td>AVX2</td>
-<td>〰️ Partial Support</td>
+<td>〰️</td>
+<td>〰️</td>
+<td>❌</td>
 </tr>
 <tr>
 <td></td>
 <td>🟩 NVIDIA GPU <br><code>cuda</code></td>
 <td>SM50+ minimum<br>SM75+ recommended</td>
-<td>✅ Full Support</td>
+<td>✅</td>
+<td>✅</td>
+<td>✅</td>
 </tr>
 <tr>
 <td></td>
@@ -57,7 +73,9 @@ bitsandbytes has the following minimum requirements for all platforms:
 CDNA: gfx90a, gfx942<br>
 RDNA: gfx1100, gfx1200
 </td>
-<td>🚧 In Development</td>
+<td>🚧</td>
+<td>🚧</td>
+<td>🚧</td>
 </tr>
 <tr>
 <td></td>
@@ -67,25 +85,33 @@ bitsandbytes has the following minimum requirements for all platforms:
 Arc A-Series (Alchemist)<br>
 Arc B-Series (Battlemage)
 </td>
-<td>🚧 In Development</td>
+<td>🚧</td>
+<td>🚧</td>
+<td>🚧</td>
 </tr>
 <tr>
 <td></td>
 <td>🟪 Intel Gaudi <br><code>hpu</code></td>
 <td>Gaudi1, Gaudi2, Gaudi3</td>
-<td>🚧 In Development</td>
+<td>🚧</td>
+<td>🚧</td>
+<td>❌</td>
 </tr>
 <tr>
 <td align="right">aarch64</td>
 <td>◻️ CPU</td>
 <td></td>
-<td>〰️ Partial Support</td>
+<td>〰️</td>
+<td>〰️</td>
+<td>❌</td>
 </tr>
 <tr>
 <td></td>
 <td>🟩 NVIDIA GPU <br><code>cuda</code></td>
 <td>SM75, SM80, SM90, SM100</td>
-<td>✅ Full Support</td>
+<td>✅</td>
+<td>✅</td>
+<td>✅</td>
 </tr>
 <tr>
 <td colspan="4">🪟 <strong>Windows 11 / Windows Server 2019+</strong></td>
@@ -94,13 +120,17 @@ bitsandbytes has the following minimum requirements for all platforms:
 <td align="right">x86-64</td>
 <td>◻️ CPU</td>
 <td>AVX2</td>
-<td>〰️ Partial Support</td>
+<td>〰️</td>
+<td>〰️</td>
+<td>❌</td>
 </tr>
 <tr>
 <td></td>
 <td>🟩 NVIDIA GPU <br><code>cuda</code></td>
 <td>SM50+ minimum<br>SM75+ recommended</td>
-<td>✅ Full Support</td>
+<td>✅</td>
+<td>✅</td>
+<td>✅</td>
 </tr>
 <tr>
 <td></td>
@@ -109,7 +139,9 @@ bitsandbytes has the following minimum requirements for all platforms:
 Arc A-Series (Alchemist) <br>
 Arc B-Series (Battlemage)
 </td>
-<td>🚧 In Development</td>
+<td>🚧</td>
+<td>🚧</td>
+<td>🚧</td>
 </tr>
 <tr>
 <td colspan="4">🍎 <strong>macOS 13.1+</strong></td>
@@ -118,13 +150,17 @@ bitsandbytes has the following minimum requirements for all platforms:
 <td align="right">arm64</td>
 <td>◻️ CPU</td>
 <td>Apple M1+</td>
-<td>🚧 In Development</td>
+<td>🚧</td>
+<td>🚧</td>
+<td>❌</td>
 </tr>
 <tr>
 <td></td>
 <td>⬜ Metal <br><code>mps</code></td>
 <td>Apple M1+</td>
-<td>🚧 In Development</td>
+<td>🚧</td>
+<td>🚧</td>
+<td>❌</td>
 </tbody>
 </table>

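The three new columns map each platform to the library's three feature families. A minimal smoke-test sketch for one cell of the table, assuming a CUDA build of bitsandbytes (the layer and optimizer classes below are the library's public API; `python -m bitsandbytes` prints fuller environment diagnostics):

    import torch
    import bitsandbytes as bnb

    # LLM.int8() column: 8-bit linear layer, weights quantized on .cuda()
    linear8 = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False).cuda()
    # QLoRA 4-bit column: NF4-quantized linear layer
    linear4 = bnb.nn.Linear4bit(64, 64, quant_type="nf4").cuda()
    # 8-bit Optimizers column: blockwise 8-bit Adam over ordinary parameters
    opt = bnb.optim.Adam8bit(torch.nn.Linear(64, 64).cuda().parameters())

    x = torch.randn(1, 64, device="cuda", dtype=torch.float16)
    linear8(x)  # should raise on unsupported platforms rather than silently degrade
    linear4(x)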
bitsandbytes/nn/modules.py

Lines changed: 2 additions & 9 deletions
@@ -291,13 +291,6 @@ def from_prequantized(
 
         return self
 
-    @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        with torch._C.DisableTorchFunctionSubclass():
-            return func(*args, **kwargs)
-
     def _quantize(self, device):
         w = self.data.contiguous().to(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(
@@ -455,14 +448,14 @@ def set_compute_type(self, x):
             self.compute_dtype = x.dtype
         elif x.dtype == torch.float16:
             # we take the compute dtype passed into the layer
-            if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
+            if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
                 # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                 # warn the user about this
                 warnings.warn(
                     "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
                 )
                 warnings.filterwarnings("ignore", message=".*inference.")
-            if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
+            if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
                 warnings.warn(
                     "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
                 )

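Two changes here: dropping the `__torch_function__` override restores PyTorch's default tensor-subclass dispatch for the 4-bit parameter class, and the broadened `in [None, torch.float32]` guard means the slow-inference warning now also fires for the default configuration, since `compute_dtype` is `None` until a caller sets it. A minimal sketch of the newly covered default path, assuming a CUDA device:

    import warnings

    import torch
    import bitsandbytes as bnb

    # compute_dtype is left at its default (None); previously this path
    # skipped the warning because only an explicit torch.float32 matched.
    layer = bnb.nn.Linear4bit(64, 64, quant_type="nf4").cuda()
    x = torch.rand(1, 64, device="cuda", dtype=torch.float16)  # single-batch fp16 input

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        layer(x)  # the first forward pass calls set_compute_type()

    assert any("slow inference" in str(w.message) for w in caught)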
tests/test_linear4bit.py

Lines changed: 1 addition & 4 deletions
@@ -270,10 +270,7 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
 @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
 @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
 def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
-    if device == "cpu" and quant_type == "fp4":
-        pytest.skip("FP4 is not supported for CPU")
-
-    if fullgraph and torch.__version__ < (2, 8):
+    if fullgraph and torch.__version__ < (2, 8, 0, "dev"):
         pytest.skip("fullgraph mode requires torch 2.8 or higher")
 
     if device == "cuda" and platform.system() == "Windows":

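The version-guard change matters for nightly builds: `torch.__version__` is a PEP 440-style version, and a pre-release such as `2.8.0.devNNNNNNNN` sorts before `2.8`, so the old `< (2, 8)` check skipped nightlies that already support fullgraph compilation. A hedged illustration using `packaging` (torch's `TorchVersion` follows the same ordering; the date below is made up):

    from packaging.version import Version

    nightly = Version("2.8.0.dev20250601")   # hypothetical nightly version string
    print(nightly < Version("2.8"))          # True:  old guard would skip this build
    print(nightly < Version("2.8.0.dev"))    # False: new guard lets the test run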
tests/test_modules.py

Lines changed: 4 additions & 12 deletions
@@ -440,31 +440,23 @@ def test_4bit_linear_warnings(device):
     dim1 = 64
 
     with pytest.warns(UserWarning, match=r"inference or training"):
-        net = nn.Sequential(
-            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
-        )
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
         net = net.to(device)
         inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
         net(inp)
     with pytest.warns(UserWarning, match=r"inference."):
-        net = nn.Sequential(
-            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
-        )
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
         net = net.to(device)
         inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
         net(inp)
 
     with pytest.warns(UserWarning) as record:
-        net = nn.Sequential(
-            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
-        )
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
         net = net.to(device)
         inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
         net(inp)
 
-        net = nn.Sequential(
-            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
-        )
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
         net = net.to(device)
         inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
         net(inp)

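These tests previously passed compute_dtype=torch.float32 explicitly to provoke the warnings; with the broadened check in set_compute_type above, the default compute_dtype=None now triggers the same warnings, so the tests exercise the layer's out-of-the-box configuration.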