
Commit bdab075

rebase main branch

Signed-off-by: jiqing-feng <[email protected]>
Merge commit, 2 parents: 622c0ab + 39dd847

File tree: 12 files changed, +25 -35 lines


bitsandbytes/autograd/_functions.py

Lines changed: 3 additions & 3 deletions
@@ -422,9 +422,9 @@ def matmul(
     if threshold > 0.0:
         state.threshold = threshold
     # MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
-    if state.is_training and A.device.type in ("cpu", "xpu"):
-        return MatMul8bitFp.apply(A, B, out, bias, state)
-
+    if state.is_training:
+        if A.device.type in ("cpu", "xpu"):
+            return MatMul8bitFp.apply(A, B, out, bias, state)
     return MatMul8bitLt.apply(A, B, out, bias, state)
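
Note: the restructured condition is behaviour-preserving. Training on CPU or XPU still takes the MatMul8bitFp fallback (per the comment, there is no fast 8-bit quant/dequant kernel on those devices), while CUDA and all inference calls keep going through MatMul8bitLt. A minimal, self-contained sketch of that routing (plain Python for illustration, not the bitsandbytes API; the helper name pick_matmul_path is hypothetical):

def pick_matmul_path(is_training: bool, device_type: str) -> str:
    # Hypothetical helper mirroring the dispatch in matmul() above.
    if is_training:
        if device_type in ("cpu", "xpu"):
            # No fast 8-bit quant/dequant kernel on CPU/XPU -> fp fallback.
            return "MatMul8bitFp"
    return "MatMul8bitLt"

assert pick_matmul_path(True, "xpu") == "MatMul8bitFp"
assert pick_matmul_path(False, "xpu") == "MatMul8bitLt"
assert pick_matmul_path(True, "cuda") == "MatMul8bitLt"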

bitsandbytes/backends/utils.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
     import triton.language as tl  # noqa: F401
 
     triton_available = True
-except ImportError as e:
+except ImportError:
     triton_available = False

bitsandbytes/backends/xpu/ops.py

Lines changed: 11 additions & 0 deletions
@@ -2,6 +2,7 @@
 import ctypes as ct
 import logging
 
+from packaging import version
 import torch
 
 from bitsandbytes.functional import _get_tensor_stream, get_ptr
@@ -12,6 +13,16 @@
 
 logger = logging.getLogger(__name__)
 
+# _int_mm is available in torch starting from 2.9 version
+if version.parse(torch.__version__).release >= version.parse("2.9").release:
+
+    @register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
+    def _(A: torch.Tensor, B: torch.Tensor):
+        return torch._int_mm(
+            A.reshape(-1, A.shape[-1]),
+            B.t(),
+        ).reshape(*A.shape[:-1], B.shape[0])
+
 
 def _dequantize_4bit_impl(
     A: torch.Tensor,
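
Per the comment above, the kernel is only registered when the installed torch is at least 2.9, where torch._int_mm can be used for this op on XPU. The shape contract it implements: A is (..., K) int8, B is (N, K) int8, and the result is (..., N) int32 (A is flattened to 2-D, multiplied against B.t(), then the leading dimensions are restored). A rough reference sketch, assuming those shapes and emulating torch._int_mm with a plain int32 matmul so it runs on CPU with any recent torch (int8_linear_matmul_ref is a made-up name, not a bitsandbytes function):

import torch

def int8_linear_matmul_ref(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Flatten leading dims, matmul against B.t() with int32 accumulation,
    # then restore the leading dims -- same reshape pattern as the kernel above.
    A2d = A.reshape(-1, A.shape[-1])                    # (M, K), M = prod of leading dims
    out = A2d.to(torch.int32) @ B.t().to(torch.int32)   # (M, N)
    return out.reshape(*A.shape[:-1], B.shape[0])       # (..., N)

A = torch.randint(-128, 128, (2, 4, 8), dtype=torch.int8)  # activations (..., K)
B = torch.randint(-128, 128, (16, 8), dtype=torch.int8)    # weight (N, K)
print(int8_linear_matmul_ref(A, B).shape)                   # torch.Size([2, 4, 16])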

bitsandbytes/functional.py

Lines changed: 2 additions & 13 deletions
@@ -242,7 +242,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     assert e + p == total_bits - has_sign
     # the exponent is biased to 2^(e-1) -1 == 0
     evalues = []
-    pvalues = []
     for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)):
         evalues.append(2**val)
@@ -1357,8 +1356,6 @@ def optimizer_update_8bit_blockwise(
     gnorm_scale: float = 1.0,
     skip_zeros=False,
 ) -> None:
-    optim_func = None
-
     is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
 
     torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
@@ -2089,7 +2086,7 @@ def spmm_coo(
     assert cooA.values.numel() == nnz
     assert cooA.cols == B.shape[0]
 
-    transposed_B = False if B.is_contiguous() else True
+    transposed_B = not B.is_contiguous()
 
     ldb = B.stride()[(1 if transposed_B else 0)]
     ldc = B.shape[1]
@@ -2138,12 +2135,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     assert cooA.values.numel() == nnz
     assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"
 
-    transposed_B = False if B.is_contiguous() else True
-
-    ldb = B.stride()[(1 if transposed_B else 0)]
-    ldc = B.shape[1]
-
-    values, counts = torch.unique(cooA.rowidx, return_counts=True)
+    _, counts = torch.unique(cooA.rowidx, return_counts=True)
     offset = counts.cumsum(0).int()
     max_count, max_idx = torch.sort(counts, descending=True)
     max_idx = max_idx.int()
@@ -2163,11 +2155,8 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     cnnz_rows = ct.c_int32(counts.numel())
     cnnz = ct.c_int32(cooA.nnz)
     crowsA = ct.c_int32(cooA.rows)
-    ccolsA = ct.c_int32(cooA.cols)
     crowsB = ct.c_int32(B.shape[1])
     ccolsB = ct.c_int32(B.shape[1])
-    cldb = ct.c_int32(ldb)
-    cldc = ct.c_int32(ldc)
 
     with _cuda_device_of(B):
         is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
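
In the create_fp8_map hunk above, the deleted pvalues list was never used; the surviving loop only builds the exponent table 2**val for val in the range [-2**(exponent_bits - has_sign), 2**(exponent_bits - has_sign)). A tiny illustration of just that loop, with small assumed parameters (exponent_bits=3, signed, not the defaults in the diff):

exponent_bits, has_sign = 3, 1  # assumed toy values for illustration
evalues = [2**val for val in range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign))]
print(evalues)  # [0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8]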

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 2 deletions
@@ -476,7 +476,7 @@ def __init__(
         )
         # self.persistent_buffers = [] # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
-        self.compute_type_is_set = False if compute_dtype is None else True
+        self.compute_type_is_set = compute_dtype is not None
         self.quant_state = None
         self.quant_storage = quant_storage
@@ -1117,4 +1117,4 @@ def forward(self, x):
         if self.weight.CB is not None:
             self.init_8bit_state()
 
-        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        return bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias

bitsandbytes/optim/lars.py

Lines changed: 0 additions & 3 deletions
@@ -231,9 +231,6 @@ def step(self, closure=None):
                 loss = closure()
 
         for group in self.param_groups:
-            params_with_grad = []
-            d_p_list = []
-            momentum_buffer_list = []
             weight_decay = group["weight_decay"]
             momentum = group["momentum"]
             dampening = group["dampening"]

bitsandbytes/optim/optimizer.py

Lines changed: 0 additions & 2 deletions
@@ -272,8 +272,6 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()
 
-        overflows = []
-
         if not self.initialized:
             self.check_overrides()
             self.to_gpu()  # needed for fairseq pure fp16 training

bitsandbytes/research/autograd/_functions.py

Lines changed: 1 addition & 1 deletion
@@ -235,7 +235,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non
         # 2. Quantize B
         if state.has_fp16_weights:
             # print('B shape', B.shape)
-            has_grad = True if (getattr(B, "grad", None) is not None) else False
+            has_grad = getattr(B, "grad", None) is not None
             is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
             if is_transposed:
                 B = B.contiguous()

bitsandbytes/utils.py

Lines changed: 0 additions & 5 deletions
@@ -84,11 +84,6 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
     if rdm:
         return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
 
-    m = weight.mean(reduction_dim)
-    mm = m.mean()
-    mstd = m.std()
-    zm = (m - mm) / mstd
-
     std = weight.std(reduction_dim)
     stdm = std.mean()
     stdstd = std.std()

install_cuda.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def main():
 
     # Install CUDA version(s)
     if version == "all":
-        for ver in cuda_versions.keys():
+        for ver in cuda_versions:
             install_cuda(ver, base_path, download_path)
     elif version in cuda_versions:
         install_cuda(version, base_path, download_path)
