Commit 1bc28db

Merge conflict
2 parents f74995b + 70bbbb9

8 files changed: 98 additions & 25 deletions

8 files changed

+98
-25
lines changed

.clang-format

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+---
+BasedOnStyle: LLVM
+AlignAfterOpenBracket: BlockIndent
+BinPackArguments: true
+BinPackParameters: true
+BracedInitializerIndentWidth: 4
+ColumnLimit: 120
+Cpp11BracedListStyle: true
+IndentWidth: 4
+IndentWrappedFunctionNames: true
+PointerAlignment: Left
+SeparateDefinitionBlocks: Always
+Standard: c++17
+StatementMacros:
+  - 'MAKE_PreconditionOptimizer32bit1State'
+  - 'MAKE_PreconditionOptimizer32bit2State'
+  - 'MAKE_PreconditionStatic8bit1State'
+  - 'MAKE_PreconditionStatic8bit2State'
+  - 'MAKE_Optimizer32bit1State'
+  - 'MAKE_optimizerStatic8bit1State'
+  - 'MAKE_optimizerStatic8bit2State'
+  - 'MAKE_OptimizerStatic8bit1StateBlockwise'
+  - 'MAKE_OptimizerStatic8bit2StateBlockwise'
+  - 'MAKE_kQuantizeBlockwise'
+  - 'MAKE_BLOCKWISE8'
+  - 'MAKE_ELEMENTWISE_FUNC'
+  - 'CMAKE_ELEMENTWISE_FUNC'
+  - 'MAKE_FUNC8'
+  - 'MAKE_FUNC32'
+  - 'MAKE_CBLOCKWISE8'
+  - 'MAKE_CFUNC8'
+  - 'MAKE_CFUNC32'
+
+UseTab: Never
+
+...

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
@@ -21,3 +21,9 @@ repos:
     rev: v1.26.0
     hooks:
      - id: typos
+  - repo: https://github.com/pre-commit/mirrors-clang-format
+    rev: v20.1.6
+    hooks:
+      - id: clang-format
+        types_or: [c++, c, cuda]
+        files: ^csrc/
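
The hook only formats files under csrc/ with C, C++, or CUDA types. A rough local equivalent, as a sketch assuming clang-format is on PATH and the repo root is the working directory (so the new .clang-format is picked up automatically):

import pathlib
import subprocess

# Format the sources that the hook's `files: ^csrc/` filter would match.
for path in pathlib.Path("csrc").rglob("*"):
    if path.suffix in {".c", ".h", ".cpp", ".cu", ".cuh"}:
        subprocess.run(["clang-format", "-i", str(path)], check=True)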

bitsandbytes/backends/hpu/ops.py

Lines changed: 0 additions & 5 deletions
@@ -29,8 +29,6 @@ def _(
     if A.dtype != torch.uint8:
         A = A.view(torch.uint8)
 
-    transpose = False if len(A.shape) == 2 and A.shape[0] == 1 else True
-
     A = A.reshape(-1)
 
     if GAUDI_SW_VER and (GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22):
@@ -47,7 +45,4 @@ def _(
 
     output = out_dq.reshape(shape)
 
-    if transpose:
-        output = output.t()
-
     return output
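
With the conditional transpose removed, the op's output keeps the shape recorded at quantize time. A minimal sketch of the new tail of the op (dequantize_tail and the zero-filled input are illustrative stand-ins, not the real HPU kernel):

import torch

def dequantize_tail(out_dq: torch.Tensor, shape: tuple) -> torch.Tensor:
    # After this commit: restore the recorded shape directly, with no trailing .t().
    return out_dq.reshape(shape)

out = dequantize_tail(torch.zeros(12, dtype=torch.bfloat16), (3, 4))
assert out.shape == (3, 4)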

tests/helpers.py

Lines changed: 11 additions & 0 deletions
@@ -98,3 +98,14 @@ def id_formatter(label: str):
 
 def describe_dtype(dtype: torch.dtype) -> str:
     return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]
+
+
+def is_supported_on_hpu(
+    quant_type: str = "nf4", dtype: torch.dtype = torch.bfloat16, quant_storage: torch.dtype = torch.uint8
+) -> bool:
+    """
+    Check if the given quant_type, dtype and quant_storage are supported on HPU.
+    """
+    if quant_type == "fp4" or dtype == torch.float16 or quant_storage not in (torch.uint8, torch.bfloat16):
+        return False
+    return True
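
For reference, how the helper evaluates under its own rules. These assertions are illustrative, not part of the commit; they assume the script runs from the repo root so tests.helpers is importable:

import torch
from tests.helpers import is_supported_on_hpu

assert is_supported_on_hpu("nf4", torch.bfloat16, torch.uint8)        # the supported HPU path
assert not is_supported_on_hpu("fp4")                                 # fp4 is rejected
assert not is_supported_on_hpu("nf4", torch.float16)                  # float16 compute is rejected
assert not is_supported_on_hpu("nf4", torch.bfloat16, torch.float32)  # storage must be uint8 or bfloat16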

tests/test_autograd.py

Lines changed: 3 additions & 2 deletions
@@ -8,6 +8,7 @@
     describe_dtype,
     get_available_devices,
     id_formatter,
+    is_supported_on_hpu,
 )
 
 TRANSPOSE_VALS = [(False, True), (False, False)]
@@ -189,8 +190,8 @@ def test_matmul_4bit(
     if device == "cpu" and dtype != torch.float32 and any(req_grad) and torch.__version__ < (2, 6):
         pytest.xfail("mse_loss fp16 on CPU is not supported in torch < 2.6")
 
-    if device == "hpu" and quant_type != "nf4":
-        pytest.skip("HPU only supports nf4")
+    if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
+        pytest.skip("This configuration is not supported on HPU.")
 
     for i in range(3):
         # normal multiply

tests/test_functional.py

Lines changed: 12 additions & 10 deletions
@@ -16,6 +16,7 @@
     get_available_devices,
     get_test_dims,
     id_formatter,
+    is_supported_on_hpu,
 )
 
 torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@@ -1101,8 +1102,8 @@ class TestQuantize4BitFunctional:
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
     def test_4bit_quant(self, device, dtype, quant_type, blocksize):
-        if device == "hpu" and quant_type != "nf4":
-            pytest.skip("fp4 dequantization is not supported on HPU")
+        if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
+            pytest.skip("This configuration is not supported on HPU.")
 
         A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
         qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
@@ -1135,14 +1136,15 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
-    def test_4bit_compressed_stats(self, device, quant_type, blocksize):
-        if device == "hpu" and quant_type != "nf4":
-            pytest.skip("fp4 dequantization is not supported on HPU")
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
+    def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
+        if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
+            pytest.skip("This configuration is not supported on HPU.")
 
         errs1 = []
         errs2 = []
         for i in range(10):
-            A1 = torch.randn(1024, 1024, device=device).half()
+            A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
             q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
             q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
             A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
@@ -1211,8 +1213,8 @@ def test_bench_4bit_dequant(self, quant_type):
     )
     @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
     def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
-        if device == "hpu":
-            pytest.skip("gemv not supported on HPU")
+        if device == "hpu" and not is_supported_on_hpu(storage_type, dtype, quant_storage):
+            pytest.skip("This configuration is not supported on HPU.")
 
         errs1 = []
         errs2 = []
@@ -1363,8 +1365,8 @@ def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
         if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
             pytest.skip("eye does not support bfloat16 on CPU in torch < 2.3")
 
-        if device == "hpu" and storage_type != "nf4":
-            pytest.skip("fp4 dequantization is not supported on HPU")
+        if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
+            pytest.skip("This configuration is not supported on HPU.")
 
         dims = 10
         torch.random.manual_seed(np.random.randint(0, 412424242))

tests/test_linear4bit.py

Lines changed: 22 additions & 3 deletions
@@ -13,6 +13,7 @@
     describe_dtype,
     get_available_devices,
     id_formatter,
+    is_supported_on_hpu,
     torch_load_from_buffer,
     torch_save_to_buffer,
 )
@@ -27,12 +28,17 @@
 
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
+@pytest.mark.parametrize("original_dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
-def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
-    original_dtype = torch.float16
+def test_linear_serialization(
+    device, quant_type, original_dtype, compress_statistics, bias, quant_storage, save_before_forward
+):
+    if device == "hpu" and not is_supported_on_hpu(quant_type, original_dtype, storage[quant_storage]):
+        pytest.skip("This configuration is not supported on HPU.")
+
     compute_dtype = None
     layer_shape = (300, 400)
 
@@ -188,6 +194,9 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua
 @pytest.mark.parametrize("blocksize", [64, 128])
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 def test_copy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "hpu" and not is_supported_on_hpu(quant_type):
+        pytest.skip("This configuration is not supported on HPU.")
+
     tensor = torch.randn(300, 400)
     param = bnb.nn.Params4bit(
         data=tensor,
@@ -207,6 +216,9 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
 @pytest.mark.parametrize("blocksize", [64, 128])
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "hpu" and not is_supported_on_hpu(quant_type):
+        pytest.skip("This configuration is not supported on HPU.")
+
     tensor = torch.randn(300, 400)
     param = bnb.nn.Params4bit(
         data=tensor,
@@ -233,6 +245,9 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
 @pytest.mark.parametrize("blocksize", [64, 128])
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
+    if device == "hpu" and not is_supported_on_hpu(quant_type):
+        pytest.skip("This configuration is not supported on HPU.")
+
     original_tensor = torch.randn(300, 400)
     original_param = bnb.nn.Params4bit(
         data=original_tensor,
@@ -270,6 +285,9 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
 @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
 @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
 def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
+    if device == "hpu" and not is_supported_on_hpu(quant_type):
+        pytest.skip("This configuration is not supported on HPU.")
+
     if fullgraph and torch.__version__ < (2, 8, 0, "dev"):
         pytest.skip("fullgraph mode requires torch 2.8 or higher")
 
@@ -317,7 +335,8 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
     ref_output = net(x)
 
     # Compile the model
-    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
+    compile_backend = "hpu_backend" if device == "hpu" else "inductor"
+    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode, backend=compile_backend)
 
     # Get output from compiled model
     with torch.no_grad():
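
A self-contained sketch of the backend switch introduced above; compile_for_device is a hypothetical helper, and "hpu_backend" is only registered by Habana's PyTorch bridge, so everywhere else this falls back to inductor:

import torch

def compile_for_device(net: torch.nn.Module, device: str) -> torch.nn.Module:
    # Mirrors the test change: use Habana's compile backend only when targeting HPU.
    backend = "hpu_backend" if device == "hpu" else "inductor"
    return torch.compile(net, backend=backend)

net = torch.nn.Linear(4, 4)
compiled = compile_for_device(net, device="cpu")
_ = compiled(torch.randn(2, 4))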

tests/test_ops.py

Lines changed: 8 additions & 5 deletions
@@ -5,7 +5,7 @@
 
 import bitsandbytes
 from bitsandbytes.functional import ipex_xpu
-from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter
+from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu
 
 # torch.library.opcheck is only available in torch 2.4 and later.
 # When testing with older versions, we will skip it as a no-op.
@@ -158,6 +158,9 @@ class Test4bitBlockwiseQuantOps:
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
+        if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
+            pytest.skip("This configuration is not supported on HPU.")
+
         A = torch.randn(1024, 1024, dtype=dtype, device=device)
 
         out, absmax = torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, storage_dtype)
@@ -179,8 +182,8 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
-        if device == "hpu" and quant_type != "nf4":
-            pytest.skip("fp4 dequantization is not supported on HPU")
+        if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
+            pytest.skip("This configuration is not supported on HPU.")
 
         shape = (128, 128)
 
@@ -213,8 +216,8 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
-        if device == "hpu":
-            pytest.skip("gemv not supported on HPU")
+        if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
+            pytest.skip("This configuration is not supported on HPU.")
 
         out_features = 1024
         in_features = 256
