Commit fd74c06

cleanup
1 parent a61c0fa commit fd74c06

File tree

6 files changed: +6 -339 lines changed


bitsandbytes/__init__.py

Lines changed: 0 additions & 3 deletions

@@ -7,11 +7,8 @@
 from . import _ops, research, utils
 from .autograd._functions import (
     MatmulLtState,
-    bmm_cublas,
     matmul,
     matmul_4bit,
-    matmul_cublas,
-    mm_cublas,
 )
 from .backends.cpu import ops as cpu_ops
 from .backends.cuda import ops as cuda_ops  ## TODO: We would guard this for CUDA only
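The removed aliases (mm_cublas, bmm_cublas, matmul_cublas) all pointed at the deprecated MatMul8bit autograd function, whose deprecation notice directs users to MatMul8bitLt. A minimal migration sketch, not part of this commit, assuming fp16 CUDA tensors and the default quantization state; note that bnb.matmul follows linear-layer conventions, taking B in (out_features, in_features) layout:

import torch
import bitsandbytes as bnb

A = torch.randn(8, 64, dtype=torch.float16, device="cuda")   # (batch, in_features)
W = torch.randn(32, 64, dtype=torch.float16, device="cuda")  # (out_features, in_features)

# Before this commit: out = bnb.matmul_cublas(A, W.t())
# The MatMul8bitLt path computes A @ W.t() internally via int8 GEMM.
out = bnb.matmul(A, W)
assert out.shape == (8, 32)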

bitsandbytes/autograd/_functions.py

Lines changed: 0 additions & 115 deletions

@@ -106,121 +106,6 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
     return outputs.reshape(rows, cols).contiguous()


-@deprecated(
-    "MatMul8bit is deprecated and will be removed in a future release. Please use MatMul8bitLt instead.",
-    category=FutureWarning,
-)
-class MatMul8bit(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
-        if precision is None:
-            precision = [8, 8, 8]
-        if precision[0] != 8:
-            with torch.no_grad():
-                output = torch.matmul(A, B)
-        else:
-            if len(B.shape) == 2:
-                dim = 0
-            else:
-                dim = 1
-            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
-            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
-            iout = F.igemm(qA, qB)
-            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)
-
-        if A.requires_grad or B.requires_grad:
-            ctx.save_for_backward(A, B)
-
-        ctx.quant_type = quant_type
-        ctx.precision = precision
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        A, B = ctx.saved_tensors
-        quant_type = ctx.quant_type
-        precision = ctx.precision
-        grad_A = grad_B = None
-
-        if B.requires_grad:
-            if len(A.shape) == 3:
-                dims = [0, 1]
-                # bsi -> ibs
-                permute_dim = [0, 2, 1]
-            else:
-                dims = [0]
-                # bs -> sb
-                permute_dim = [1, 0]
-
-            if precision[1] != 8:
-                with torch.no_grad():
-                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
-            else:
-                if len(B.shape) == 2 and len(A.shape) == 3:
-                    grad_output = grad_output.contiguous()
-                    if not grad_output.is_contiguous():
-                        grad_output.contiguous()
-                    qgrad_output, S1 = F.vectorwise_quant(
-                        grad_output.view(-1, grad_output.shape[2]),
-                        dim=0,
-                        quant_type=quant_type,
-                    )
-                    if not A.is_contiguous():
-                        A = A.contiguous()
-                    qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
-                    igrad_B = F.igemm(qA.t(), qgrad_output)
-                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
-                else:
-                    qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
-                    qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
-                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
-                    grad_B = F.vectorwise_mm_dequant(
-                        igrad_B,
-                        S2.permute(permute_dim),
-                        S1,
-                        grad_output.dtype,
-                        quant_type,
-                    )
-
-        if A.requires_grad:
-            if len(grad_output.shape) == 3:
-                dims = [2]
-            else:
-                dims = [1]
-
-            if len(B.shape) == 3:
-                # bio -> boi
-                permute_dim = [0, 2, 1]
-                dim_B = dims
-            else:
-                # io -> oi
-                permute_dim = [1, 0]
-                dim_B = [1]
-
-            if precision[2] != 8:
-                with torch.no_grad():
-                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
-            else:
-                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
-                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
-                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
-                grad_A = F.vectorwise_mm_dequant(
-                    igrad_A,
-                    S1,
-                    S3.permute(permute_dim),
-                    grad_output.dtype,
-                    quant_type,
-                )
-
-        return grad_A, grad_B, None, None, None
-
-
-mm_cublas = MatMul8bit.apply
-bmm_cublas = MatMul8bit.apply
-matmul_cublas = MatMul8bit.apply
-
-
 @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def supports_igemmlt(device: torch.device) -> bool:
     """check if this device supports the optimized int8 kernel"""

bitsandbytes/functional.py

Lines changed: 5 additions & 14 deletions

@@ -1541,21 +1541,12 @@ def optimizer_update_8bit_blockwise(
 
     is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
 
-    print(
-        f"{p.device} {g.device} {state1.device} {state2.device} {qmap1.device} {qmap2.device} {absmax1.device} {absmax2.device} \n\n"
-        f"{p.dtype} {g.dtype} {state1.dtype} {state2.dtype} {qmap1.dtype} {qmap2.dtype} {absmax1.dtype} {absmax2.dtype} \n\n"
-        f"{p.__class__} {g.__class__} {state1.__class__} {state2.__class__} {qmap1.__class__} {qmap2.__class__} {absmax1.__class__} {absmax2.__class__} \n\n"
-        f"{p.data_ptr()} {g.data_ptr()} {state1.data_ptr()} {state2.data_ptr()} {qmap1.data_ptr()} {qmap2.data_ptr()} {absmax1.data_ptr()} {absmax2.data_ptr()} \n\n"
-    )
-
-    print(p, g, state1, state2)
-
     with _cuda_device_of(g):
         optim_func(
-            get_ptr(p.to_local()),
-            get_ptr(g.to_local()),
-            get_ptr(state1.to_local()),
-            get_ptr(state2.to_local()),
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
             ct.c_float(beta1),
             ct.c_float(beta2),
             ct.c_float(beta3),
@@ -1570,7 +1561,7 @@ def optimizer_update_8bit_blockwise(
             ct.c_float(weight_decay),
             ct.c_float(gnorm_scale),
             ct.c_bool(skip_zeros),
-            ct.c_int32(g.to_local().numel()),
+            ct.c_int32(g.numel()),
         )
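The dropped .to_local() calls were DTensor-style unwrapping, and the deleted prints dumped device, dtype, class, and pointer info for debugging; with both gone, get_ptr again receives the tensors as-is. A rough sketch of the ctypes marshalling pattern at this call site (get_ptr_sketch below is an illustrative stand-in, not the library's get_ptr):

import ctypes as ct
import torch

def get_ptr_sketch(t):
    # Essentially what bitsandbytes' get_ptr does: None maps to a null
    # pointer, otherwise the tensor's raw data pointer is wrapped for ctypes.
    return None if t is None else ct.c_void_p(t.data_ptr())

g = torch.zeros(1024, device="cuda")
# Scalar arguments are passed with explicit C types, as in the hunk above.
args = (get_ptr_sketch(g), ct.c_float(0.9), ct.c_int32(g.numel()))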

docs/source/installation.mdx

Lines changed: 1 addition & 1 deletion

@@ -174,7 +174,7 @@ export BNB_CUDA_VERSION=126
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-12.6
 ```
 
-3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded.
+3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 12.6) and a different bitsandbytes library is loaded.
 
 ## Multi-backend Support (Alpha Release)[[multi-backend]]
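A quick way to exercise the override described in that step (a sketch, assuming BNB_CUDA_VERSION is read at import time as the docs describe; it must be set before bitsandbytes first loads):

import os

os.environ["BNB_CUDA_VERSION"] = "126"  # set before the first import
import bitsandbytes  # should now pick up the CUDA 12.6 build of the native library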

tests/test_deprecated.py

Lines changed: 0 additions & 192 deletions

@@ -1,201 +1,9 @@
-from typing import Tuple
-
 import numpy as np
 import pytest
 from scipy.stats import norm
 import torch
 
-import bitsandbytes as bnb
 from bitsandbytes import functional as F
-from tests.helpers import (
-    BOOLEAN_TUPLES,
-    describe_dtype,
-    get_test_dims,
-    id_formatter,
-)
-
-
-@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
-@pytest.mark.parametrize(
-    "funcs",
-    [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
-    ids=["func=bmm", "func=matmul"],
-)
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
-@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
-@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
-@pytest.mark.deprecated
-def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
-    if dim2 > 0:
-        dim2 = dim2 - (dim2 % 16)
-    dim3 = dim3 - (dim3 % 16)
-    dim4 = dim4 - (dim4 % 16)
-    for i in range(25):
-        # normal multiply
-        if funcs[0] in [torch.mm, torch.matmul]:
-            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
-            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
-            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
-            torch.nn.init.xavier_uniform_(B)
-
-            if not transpose[0] and not transpose[1]:
-                out_torch = funcs[0](A, B)
-                out_bnb = funcs[1](A, B)
-            elif not transpose[0] and transpose[1]:
-                out_torch = funcs[0](A, B.t())
-                out_bnb = funcs[1](A, B.t())
-            elif transpose[0] and not transpose[1]:
-                out_torch = funcs[0](A.t(), B)
-                out_bnb = funcs[1](A.t(), B)
-            elif transpose[0] and transpose[1]:
-                out_torch = funcs[0](A.t(), B.t())
-                out_bnb = funcs[1](A.t(), B.t())
-
-            n = out_bnb.numel()
-            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-            assert (idx == 0).sum().item() < n * 0.0175
-            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
-            assert (idx == 0).sum().item() < n * 0.001
-
-            if any(req_grad):
-                out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
-                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
-                loss_bnb.backward()
-                gradA1 = A.grad
-                gradB1 = B.grad
-                A.grad = None
-                B.grad = None
-
-                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
-                loss_torch.backward()
-                gradA2 = A.grad
-                gradB2 = B.grad
-                A.grad = None
-                B.grad = None
-
-                if req_grad[0]:
-                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
-                if req_grad[1]:
-                    n = gradB1.numel()
-                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
-                    assert (idx == 0).sum().item() < n * 0.1
-                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-                    assert (idx == 0).sum().item() < n * 0.02
-                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
-
-        # batched matrix multiply
-        if funcs[0] in [torch.bmm, torch.matmul]:
-            A = torch.randn(
-                size=(dim1, dim2, dim3),
-                device="cuda",
-                requires_grad=req_grad[0],
-            )
-            B = torch.randn(
-                size=(dim1, dim3, dim4),
-                device="cuda",
-                requires_grad=req_grad[1],
-            )
-            target = torch.randn(
-                size=(dim1, dim2, dim4),
-                device="cuda",
-                requires_grad=req_grad[1],
-            )
-            torch.nn.init.xavier_uniform_(B)
-
-            out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B)
-
-            n = out_bnb.numel()
-            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-            assert (idx == 0).sum().item() < n * 0.01
-            torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)
-
-            if any(req_grad):
-                out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
-                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
-                loss_bnb.backward()
-                gradA1 = A.grad
-                gradB1 = B.grad
-                A.grad = None
-                B.grad = None
-
-                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
-                loss_torch.backward()
-                gradA2 = A.grad
-                gradB2 = B.grad
-                A.grad = None
-                B.grad = None
-
-                if req_grad[0]:
-                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
-                if req_grad[1]:
-                    n = gradB1.numel()
-                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
-                    assert (idx == 0).sum().item() < n * 0.1
-                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-                    assert (idx == 0).sum().item() < n * 0.02
-
-        if funcs[0] in [torch.matmul]:
-            dim1 = dim1 - (dim1 % 16)
-            A = torch.randn(
-                size=(dim1, dim2, dim3),
-                device="cuda",
-                requires_grad=req_grad[0],
-            )
-            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
-            target = torch.randn(
-                size=(dim1, dim2, dim4),
-                device="cuda",
-                requires_grad=req_grad[1],
-            )
-            torch.nn.init.xavier_uniform_(B)
-
-            if transpose[1]:
-                out_torch = funcs[0](A, B.t())
-                out_bnb = funcs[1](A, B.t())
-            else:
-                out_torch = funcs[0](A, B)
-                out_bnb = funcs[1](A, B)
-
-            n = out_bnb.numel()
-            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-            assert (idx == 0).sum().item() < n * 0.0175
-            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
-            assert (idx == 0).sum().item() < n * 0.001
-
-            if any(req_grad):
-                out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
-                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
-                loss_bnb.backward()
-                gradA1 = A.grad
-                gradB1 = B.grad
-                A.grad = None
-                B.grad = None
-
-                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
-                loss_torch.backward()
-                gradA2 = A.grad
-                gradB2 = B.grad
-                A.grad = None
-                B.grad = None
-
-                if req_grad[0]:
-                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
-                if req_grad[1]:
-                    n = gradB1.numel()
-                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
-                    assert (idx == 0).sum().item() < n * 0.1
-                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-                    assert (idx == 0).sum().item() < n * 0.02
 
 
 @pytest.mark.deprecated
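The deleted test used a mismatch-budget pattern rather than a strict allclose: the quantized result is compared against the floating-point reference, and only a small fraction of elements may fall outside tolerance. A self-contained sketch of that pattern (the helper name is illustrative; the thresholds mirror the deleted test):

import torch

def assert_mostly_close(out, ref, atol, rtol, max_mismatch_frac):
    # Allow a small budget of out-of-tolerance elements instead of
    # requiring every element to be close.
    mismatched = (~torch.isclose(out, ref, atol=atol, rtol=rtol)).sum().item()
    assert mismatched < out.numel() * max_mismatch_frac

out = torch.randn(64, 64)
ref = out + 0.001 * torch.randn(64, 64)  # stand-in for quantization noise
assert_mostly_close(out, ref, atol=0.01, rtol=0.1, max_mismatch_frac=0.0175)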
