Commit 9e7cdc9

Added last SwitchBack refactors. All tests green.
1 parent 008dfff commit 9e7cdc9

5 files changed: 26 additions (+26), 19 deletions (-19)

CHANGELOG.md

Lines changed: 7 additions & 0 deletions

@@ -221,3 +221,10 @@ Improvements:
 Deprecated:
 - Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
 - Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0
+
+
+### 0.38.1
+
+Features:
+- Added Int8 SwitchBack layers
+- Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`)
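For orientation, a minimal usage sketch of the Int8 SwitchBack layer named in this changelog entry. This is not part of the commit; it assumes a CUDA GPU and a triton-enabled bitsandbytes install.

# Sketch only (not from this commit): exercising bnb.nn.SwitchBackLinear,
# the Int8 SwitchBack layer listed under 0.38.1. Assumes CUDA + triton.
import torch
import bitsandbytes as bnb

layer = bnb.nn.SwitchBackLinear(128, 512).cuda().half()
x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
y = layer(x)  # the forward pass quantizes activations row-wise to int8
print(y.shape)  # torch.Size([4, 512])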

bitsandbytes/nn/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -3,4 +3,4 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, SwitchBackLinearBnb
-from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear
+from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear
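Downstream code that imported the old export must track this rename; a minimal before/after sketch (not part of the commit):

# Before this commit the export was spelled:
# from bitsandbytes.nn import SwitchBackLinearVectorized
# After this commit it is:
from bitsandbytes.nn import SwitchBackLinearVectorwise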

bitsandbytes/nn/triton_based_modules.py

Lines changed: 9 additions & 9 deletions

@@ -157,7 +157,7 @@ def __init__(
             bias: bool = True,
             device=None,
             dtype=None,
-            vectorize: bool = False,
+            vector_wise_quantization: bool = False,
             mem_efficient : bool = False,
     ):
         super().__init__(in_features, out_features, bias, device, dtype)
@@ -167,11 +167,11 @@ def __init__(
                 Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
 
         # By default, we use the global quantization.
-        self.vectorize = vectorize
-        if self.vectorize:
+        self.vector_wise_quantization = vector_wise_quantization
+        if self.vector_wise_quantization:
             self._fn = _switchback_vectorrize
             if mem_efficient:
-                print('mem efficient is not supported for vectorize mode.')
+                print('mem efficient is not supported for vector-wise quantization.')
                 exit(1)
         else:
             if mem_efficient:
@@ -188,7 +188,7 @@ def prepare_for_eval(self):
         #         m.prepare_for_eval()
         # model.apply(cond_prepare)
         print('=> preparing for eval.')
-        if self.vectorize:
+        if self.vector_wise_quantization:
             W_int8, state_W = quantize_rowwise(self.weight)
         else:
             W_int8, state_W = quantize_global(self.weight)
@@ -210,7 +210,7 @@ def forward(self, x):
         X = x.view(-1, x.size(-1))
         X_int8, state_X = quantize_rowwise(X)
 
-        if self.vectorize:
+        if self.vector_wise_quantization:
             return int8_matmul_rowwise_dequantize(
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
@@ -219,9 +219,9 @@ def forward(self, x):
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
 
-SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
-SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vectorize=False, mem_efficient=True)
-SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)
+SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
+SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
+SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
 
 # This is just the standard linear function.
 class StandardLinearFunction(torch.autograd.Function):
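The partial aliases at the end of the file simply pin the renamed keyword. A short sketch of that equivalence, not part of the commit and assuming triton is installed (otherwise the constructor bails out, as the diff's error path shows):

from bitsandbytes.nn import SwitchBackLinear, SwitchBackLinearVectorwise

# SwitchBackLinearVectorwise is partial(SwitchBackLinear, vector_wise_quantization=True),
# so both lines construct the same module type with vector-wise quantization enabled:
a = SwitchBackLinearVectorwise(64, 256)
b = SwitchBackLinear(64, 256, vector_wise_quantization=True)
assert type(a) is type(b) and a.vector_wise_quantization and b.vector_wise_quantization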

setup.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ def read(fname):
 
 setup(
     name=f"bitsandbytes",
-    version=f"0.38.0.post2",
+    version=f"0.38.1",
     author="Tim Dettmers",
     author_email="[email protected]",
     description="8-bit optimizers and matrix multiplication routines.",
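One way to confirm the bump after installing the built package; a stdlib-only check, not part of the commit:

# Verify the installed distribution reports the bumped version.
from importlib.metadata import version

assert version("bitsandbytes") == "0.38.1"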

tests/test_triton.py

Lines changed: 8 additions & 8 deletions

@@ -1,19 +1,19 @@
 import pytest
 import torch
 
+from bitsandbytes.triton.triton_utils import is_triton_available
 from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
 from bitsandbytes.nn import Linear8bitLt
 
-
-@pytest.mark.skipif(not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires a GPU with compute capability 8.0 or higher.")
-@pytest.mark.parametrize("vectorrize", [False, True])
-def test_switchback(vectorrize):
-    for dim in [83, 17, 128]:
-        for batch in [13, 128, 256]:
+@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
+                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
+@pytest.mark.parametrize("vector_wise_quantization", [False, True])
+def test_switchback(vector_wise_quantization):
+    for dim in [83]:
+        for batch in [13]:
 
             standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
-            print('vectorrize', vectorrize)
-            switchback = SwitchBackLinear(dim, 4 * dim, vectorize=vectorrize).cuda().half()
+            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
             baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
             switchback.weight.data.copy_(standard.weight)
             switchback.bias.data.copy_(standard.bias)
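To run just this updated test (it skips itself when triton or a compute-capability-8.0 GPU is absent), one option from Python itself:

# Invoke pytest programmatically on the updated test module.
import pytest

raise SystemExit(pytest.main(["tests/test_triton.py", "-v"]))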
