Commit 6f6c5e1

Merge branch 'main' into main

2 parents 39861c2 + 63f538a

29 files changed: +97 −559 lines

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -103,7 +103,7 @@ jobs:
       matrix:
         os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025, macos-15]
         # Test with the oldest supported torch version, the newest two stable/RC.
-        torch_version: ["2.3.1", "2.7.1", "2.8.0"]
+        torch_version: ["2.3.1", "2.8.0", "2.9.0"]
       include:
         - os: ubuntu-22.04
           arch: x86_64
```
.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.2
+    rev: v0.14.3
     hooks:
       - id: ruff
         args:
```

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -19,7 +19,7 @@ The library includes quantization primitives for 8-bit & 4-bit operations, throu
 ## System Requirements
 bitsandbytes has the following minimum requirements for all platforms:

-* Python 3.9+
+* Python 3.10+
 * [PyTorch](https://pytorch.org/get-started/locally/) 2.3+
   * _Note: While we aim to provide wide backwards compatibility, we recommend using the latest version of PyTorch for the best experience._
```
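Raising the Python floor to 3.10 is what enables several of the changes below: PEP 604 `X | Y` union syntax in annotations and the keyword form of `importlib.metadata.entry_points`. A minimal sketch of the union-syntax change, with hypothetical function names for illustration:

```python
from typing import Optional, Union

# Pre-3.10 spelling (what this commit removes):
def move_old(device: Optional[Union[int, str]] = None) -> None: ...

# PEP 604 spelling, valid at runtime from Python 3.10 (what this commit adopts):
def move_new(device: Optional[int | str] = None) -> None: ...
```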

benchmarking/matmul_benchmark.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -35,8 +35,8 @@ def test_bench_matmul(batch, seq, model, hidden):
     B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
     torch.nn.init.xavier_uniform_(B)

-    B_fp4, state = F.quantize_fp4(B)
-    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
+    _B_fp4, _state = F.quantize_fp4(B)
+    _B_fp4_c, _state_c = F.quantize_fp4(B, compress_statistics=True)

     B_nf4, state_nf4 = F.quantize_nf4(B)
     B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)
@@ -117,8 +117,8 @@ def test_bench_matmul(batch, seq, model, hidden):
         f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
     )

-    CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0)
-    CB, SCB, _ = F.int8_vectorwise_quant(B)
+    CA, _SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0)
+    CB, _SCB, _ = F.int8_vectorwise_quant(B)
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
```
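The renames follow the common lint convention (enforced by the ruff rules this repo runs) that bindings which are assigned but never read get a leading underscore. For context, a small sketch of how `F.int8_vectorwise_quant` is unpacked here, assuming a CUDA device; the shapes and threshold are illustrative:

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")

# Returns (int8 tensor, per-row absmax stats, outlier column indices).
CA, _SCA, _ = F.int8_vectorwise_quant(A, threshold=6.0)
print(CA.dtype, CA.shape)  # torch.int8, same shape as A
```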

bitsandbytes/__init__.py

Lines changed: 2 additions & 5 deletions

```diff
@@ -54,10 +54,7 @@ def _import_backends():
     """
     from importlib.metadata import entry_points

-    if sys.version_info < (3, 10):
-        extensions = entry_points().get("bitsandbytes.backends", [])
-    else:
-        extensions = entry_points(group="bitsandbytes.backends")
+    extensions = entry_points(group="bitsandbytes.backends")

     for ext in extensions:
         try:
@@ -75,4 +72,4 @@ def _import_backends():
     "optim.optimizer.MockArgs": False,
 }

-__version__ = "0.48.3.dev0"
+__version__ = "0.49.0.dev0"
```
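With Python 3.9 support dropped, the version branch around `importlib.metadata.entry_points` is unnecessary: the `group` keyword has been available since 3.10. A minimal sketch of consuming the same entry-point group, mirroring what `_import_backends` does:

```python
from importlib.metadata import entry_points

# Each installed package that registered under this group is discovered here.
for ext in entry_points(group="bitsandbytes.backends"):
    try:
        ext.load()  # import the registered backend module/object
    except Exception as err:
        print(f"Could not load backend {ext.name}: {err}")
```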

bitsandbytes/autograd/__init__.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -1 +0,0 @@
-from ._functions import get_inverse_transform_indices, undo_layout
```

bitsandbytes/autograd/_functions.py

Lines changed: 2 additions & 60 deletions

```diff
@@ -1,11 +1,10 @@
 from dataclasses import dataclass
 from math import prod
-from typing import Callable, Optional
+from typing import Optional
 import warnings
 from warnings import warn

 import torch
-from typing_extensions import deprecated

 import bitsandbytes.functional as F

@@ -49,66 +48,9 @@ def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)


-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def get_inverse_transform_indices(
-    transform_tile: Callable[[torch.Tensor], torch.Tensor],
-    tile_size: tuple[int, int],
-):
-    """
-    Compute a permutation of indices that invert the specified (tiled) matrix transformation
-
-    :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
-    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
-    :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
-    :example: transform_tile function for the turing layout (bitsandbytes.functional as F)
-    :returns: indices
-    """
-    d1, d2 = tile_size
-    assert 0 < d1 * d2 < 2**64
-    tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
-    # encode each position in tile as a tuple of <= 8 unique bytes
-    permuted_tile_indices = torch.zeros_like(tile_indices)
-    for i in range(8):
-        # select i-th byte, apply transformation and trace where each index ended up
-        ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
-        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
-        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
-        permuted_tile_i = transform_tile(sample_tile_i)
-        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
-        permuted_tile_indices += ith_permuted_indices * (256**i)
-        if d1 * d2 < 256**i:
-            break  # if all indices fit in i bytes, stop early
-    return permuted_tile_indices
-
-
 _is_compiling = torch.compiler.is_compiling


-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
-    """
-    Undo a tiled permutation such as turing or ampere layout
-
-    :param permuted_tensor: torch tensor in a permuted layout
-    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
-    :return: contiguous row-major tensor
-    """
-    (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
-    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
-    tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
-    outputs = torch.empty_like(tensor)  # note: not using .index_copy because it was slower on cuda
-    outputs[tile_indices.flatten()] = tensor
-    outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
-    outputs = outputs.permute(3, 0, 2, 1)  # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
-    return outputs.reshape(rows, cols).contiguous()
-
-
 @dataclass
 class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None  # TODO: remove
@@ -257,7 +199,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

         req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
-        CAt, subA, A = ctx.tensors
+        CAt, subA, _A = ctx.tensors
         SCAt, idx = ctx.tensor_states
         state: MatmulLtState = ctx.state
         grad_A = grad_B = grad_bias = None
```
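Both removed helpers had carried `typing_extensions.deprecated` markers emitting `FutureWarning`; this commit finishes the removal, which is also why the `deprecated` import goes away. A minimal sketch of that deprecation pattern, using a hypothetical `old_helper`:

```python
from typing_extensions import deprecated


@deprecated(
    "This function is deprecated and will be removed in a future release.",
    category=FutureWarning,
)
def old_helper() -> None:
    """Calling this emits a FutureWarning; type checkers also flag call sites."""


old_helper()  # FutureWarning: This function is deprecated ...
```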

bitsandbytes/backends/utils.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -4,9 +4,10 @@
 import torch

 try:
-    import triton  # noqa: F401
     import triton.language as tl  # noqa: F401

+    import triton  # noqa: F401
+
     triton_available = True
 except ImportError:
     triton_available = False
```

bitsandbytes/functional.py

Lines changed: 3 additions & 99 deletions

```diff
@@ -6,7 +6,7 @@
 import ctypes as ct
 import itertools
 from math import prod
-from typing import Any, Optional, Union
+from typing import Any, Optional

 import numpy as np
 import torch
@@ -1413,7 +1413,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
         raise ValueError(f"Gradient type {grad.dtype} not supported!")

     current_gnorm = torch.sqrt(gnorm_vec[step % 100])
-    vals, idx = torch.sort(gnorm_vec)
+    vals, _ = torch.sort(gnorm_vec)
     clip_value = torch.sqrt(vals[percentile])
     gnorm_scale = 1.0

@@ -1795,102 +1795,6 @@ def int8_mm_dequant(
     return result


-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_colrow_absmax(
-    A: torch.Tensor,
-    row_stats: Optional[torch.Tensor] = None,
-    col_stats: Optional[torch.Tensor] = None,
-    nnz_block_ptr: Optional[torch.Tensor] = None,
-    threshold=0.0,
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
-
-    The row-wise and column-wise absmax values are determined.
-
-    For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
-
-    <Tip>
-    This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead.
-    The column-wise quantization scales are not typically needed in inference scenarios.
-    </Tip>
-
-    Args:
-        A (`torch.Tensor` with dtype `torch.float16`): Input tensor.
-        row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
-        col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
-        nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
-        threshold (`float`, *optional*):
-            An optional threshold for sparse decomposition of outlier features.
-            No outliers are held back when 0.0. Defaults to 0.0.
-
-    Returns:
-        `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics.
-        - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics.
-        - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics.
-        - `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor.
-    """
-    assert A.is_floating_point()
-
-    outlier_mask = None
-
-    if row_stats is None or col_stats is None:
-        absA = A.abs().view(-1, A.shape[-1])
-
-        if threshold > 0.0:
-            # Filter outliers from stats when enabled
-            outlier_mask = absA >= threshold
-            absA.masked_fill_(outlier_mask, 0.0)
-
-        if row_stats is None:
-            # shape [rows]; unsqueeze(-1) gives [rows,1]
-            # We have a CUDA kernel for row max, but not yet for cols.
-            row_stats = get_row_absmax(A, threshold)
-
-        if col_stats is None:
-            # shape [cols]; unsqueeze(0) gives [1,cols]
-            col_stats = absA.amax(dim=0, keepdim=False).float()
-
-    return row_stats, col_stats, outlier_mask
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_row_absmax(A: torch.Tensor, threshold=0.0):
-    """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
-
-    For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
-
-    Args:
-        A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
-        threshold (`float`, *optional*):
-            An optional threshold for sparse decomposition of outlier features.
-            No outliers are held back when 0.0. Defaults to 0.0.
-
-    Returns:
-        `torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
-    """
-
-    assert A.dtype == torch.float16
-
-    rows = prod(A.shape[:-1])
-    cols = A.shape[-1]
-
-    row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device)
-
-    is_on_gpu([A])
-
-    with _cuda_device_of(A):
-        lib.cget_row_stats(
-            get_ptr(A),
-            get_ptr(row_stats),
-            ct.c_float(threshold),
-            ct.c_int32(rows),
-            ct.c_int32(cols),
-            _get_tensor_stream(A),
-        )
-
-    return row_stats
-
-
 class COOSparseTensor:
     def __init__(
         self, rows: int, cols: int, nnz: int, rowidx: torch.Tensor, colidx: torch.Tensor, values: torch.Tensor
@@ -2059,7 +1963,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):


 def spmm_coo(
-    cooA: Union[COOSparseTensor, torch.Tensor],
+    cooA: COOSparseTensor | torch.Tensor,
     B: torch.Tensor,
     out: Optional[torch.Tensor] = None,
 ):
```
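Beyond the removals, the `percentile_clipping` change is a small readability fix: `torch.sort` returns a `(values, indices)` pair and the index half was never used. A quick sketch of the pattern, with illustrative values:

```python
import torch

gnorm_vec = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0])

vals, _ = torch.sort(gnorm_vec)   # indices intentionally discarded
clip_value = torch.sqrt(vals[3])  # e.g. an illustrative percentile index
print(vals, clip_value)
```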

bitsandbytes/nn/modules.py

Lines changed: 13 additions & 13 deletions

```diff
@@ -310,28 +310,28 @@ def _quantize(self, device):
     def cpu(self):
         return self.to(device="cpu")

-    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+    def cuda(self, device: Optional[int | device | str] = None, non_blocking: bool = False):
         return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

-    def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+    def xpu(self, device: Optional[int | device | str] = None, non_blocking: bool = False):
         return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)

     @overload
     def to(
         self: T,
-        device: Optional[Union[int, device]] = ...,
-        dtype: Optional[Union[dtype, str]] = ...,
+        device: Optional[int | device] = ...,
+        dtype: Optional[dtype | str] = ...,
         non_blocking: bool = ...,
     ) -> T: ...

     @overload
-    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
+    def to(self: T, dtype: dtype | str, non_blocking: bool = ...) -> T: ...

     @overload
     def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...

     def to(self, *args, **kwargs):
-        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)

         if device is not None and device.type != "meta" and not self.bnb_quantized:
             return self._quantize(device)
@@ -644,10 +644,10 @@ def _quantize(self, device):
     def cpu(self):
         return self.to(device="cpu")

-    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+    def cuda(self, device: Optional[int | device | str] = None, non_blocking: bool = False):
         return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

-    def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+    def xpu(self, device: Optional[int | device | str] = None, non_blocking: bool = False):
         return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)

     def __deepcopy__(self, memo):
@@ -665,19 +665,19 @@ def __deepcopy__(self, memo):
     @overload
     def to(
         self: T,
-        device: Optional[Union[int, device]] = ...,
-        dtype: Optional[Union[dtype, str]] = ...,
+        device: Optional[int | device] = ...,
+        dtype: Optional[dtype | str] = ...,
         non_blocking: bool = ...,
     ) -> T: ...

     @overload
-    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
+    def to(self: T, dtype: dtype | str, non_blocking: bool = ...) -> T: ...

     @overload
     def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...

     def to(self, *args, **kwargs):
-        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)

         is_quantized = self.data.dtype == torch.int8

@@ -1048,7 +1048,7 @@ def to(self, *args, **kwargs):
         # Call the parent to() method to handle standard parameter/buffer movement
         result = super().to(*args, **kwargs)

-        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs)

         # Handle state tensors if needed.
         if device is not None:
```
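The `cuda`/`xpu`/`to` overrides exist so that moving a not-yet-quantized parameter onto an accelerator triggers quantization via `_quantize`. A minimal usage sketch, assuming a CUDA device is available; layer sizes are illustrative:

```python
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(128, 256, compute_dtype=torch.float16)
layer = layer.to("cuda")  # overridden to(): weights quantize on this first device move

x = torch.randn(1, 128, dtype=torch.float16, device="cuda")
y = layer(x)  # forward runs against the 4-bit weights
```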
