Skip to content

Commit 70bbbb9

Browse files
authored
HPU support for unit tests (#1680)
1 parent d863adb commit 70bbbb9

File tree

6 files changed

+63
-11
lines changed

6 files changed

+63
-11
lines changed

bitsandbytes/backends/hpu/ops.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ def _(
2929
if A.dtype != torch.uint8:
3030
A = A.view(torch.uint8)
3131

32-
transpose = False if len(A.shape) == 2 and A.shape[0] == 1 else True
33-
3432
A = A.reshape(-1)
3533

3634
if GAUDI_SW_VER and (GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22):
@@ -47,7 +45,4 @@ def _(
4745

4846
output = out_dq.reshape(shape)
4947

50-
if transpose:
51-
output = output.t()
52-
5348
return output

tests/helpers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,14 @@ def id_formatter(label: str):
9898

9999
def describe_dtype(dtype: torch.dtype) -> str:
100100
return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]
101+
102+
103+
def is_supported_on_hpu(
104+
quant_type: str = "nf4", dtype: torch.dtype = torch.bfloat16, quant_storage: torch.dtype = torch.uint8
105+
) -> bool:
106+
"""
107+
Check if the given quant_type, dtype and quant_storage are supported on HPU.
108+
"""
109+
if quant_type == "fp4" or dtype == torch.float16 or quant_storage not in (torch.uint8, torch.bfloat16):
110+
return False
111+
return True

tests/test_autograd.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
describe_dtype,
99
get_available_devices,
1010
id_formatter,
11+
is_supported_on_hpu,
1112
)
1213

1314
TRANSPOSE_VALS = [(False, True), (False, False)]
@@ -189,6 +190,9 @@ def test_matmul_4bit(
189190
if device == "cpu" and dtype != torch.float32 and any(req_grad) and torch.__version__ < (2, 6):
190191
pytest.xfail("mse_loss fp16 on CPU is not supported in torch < 2.6")
191192

193+
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
194+
pytest.skip("This configuration is not supported on HPU.")
195+
192196
for i in range(3):
193197
# normal multiply
194198
if funcs[0] in [torch.mm, torch.matmul]:

tests/test_functional.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_available_devices,
1717
get_test_dims,
1818
id_formatter,
19+
is_supported_on_hpu,
1920
)
2021

2122
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@@ -1101,6 +1102,9 @@ class TestQuantize4BitFunctional:
11011102
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
11021103
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
11031104
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
1105+
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
1106+
pytest.skip("This configuration is not supported on HPU.")
1107+
11041108
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
11051109
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
11061110
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
@@ -1132,11 +1136,15 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11321136
@pytest.mark.parametrize("device", get_available_devices())
11331137
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
11341138
@pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
1135-
def test_4bit_compressed_stats(self, device, quant_type, blocksize):
1139+
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
1140+
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
1141+
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
1142+
pytest.skip("FP4 quantization is not supported on HPU.")
1143+
11361144
errs1 = []
11371145
errs2 = []
11381146
for i in range(10):
1139-
A1 = torch.randn(1024, 1024, device=device).half()
1147+
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
11401148
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
11411149
q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
11421150
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
@@ -1205,6 +1213,9 @@ def test_bench_4bit_dequant(self, quant_type):
12051213
)
12061214
@pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
12071215
def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
1216+
if device == "hpu" and not is_supported_on_hpu(storage_type, dtype, quant_storage):
1217+
pytest.skip("This configuration is not supported on HPU.")
1218+
12081219
errs1 = []
12091220
errs2 = []
12101221
errs3 = []
@@ -1354,6 +1365,9 @@ def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
13541365
if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
13551366
pytest.skip("eye doe not support bfloat16 on CPU in torch < 2.3")
13561367

1368+
if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
1369+
pytest.skip("This configuration is not supported on HPU.")
1370+
13571371
dims = 10
13581372
torch.random.manual_seed(np.random.randint(0, 412424242))
13591373
dims = get_test_dims(0, 8192, n=dims)

tests/test_linear4bit.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
describe_dtype,
1414
get_available_devices,
1515
id_formatter,
16+
is_supported_on_hpu,
1617
torch_load_from_buffer,
1718
torch_save_to_buffer,
1819
)
@@ -27,12 +28,17 @@
2728

2829
@pytest.mark.parametrize("device", get_available_devices())
2930
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
31+
@pytest.mark.parametrize("original_dtype", [torch.float16, torch.bfloat16])
3032
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
3133
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
3234
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
3335
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
34-
def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
35-
original_dtype = torch.float16
36+
def test_linear_serialization(
37+
device, quant_type, original_dtype, compress_statistics, bias, quant_storage, save_before_forward
38+
):
39+
if device == "hpu" and not is_supported_on_hpu(quant_type, original_dtype, storage[quant_storage]):
40+
pytest.skip("This configuration is not supported on HPU.")
41+
3642
compute_dtype = None
3743
layer_shape = (300, 400)
3844

@@ -188,6 +194,9 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua
188194
@pytest.mark.parametrize("blocksize", [64, 128])
189195
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
190196
def test_copy_param(device, quant_type, blocksize, compress_statistics):
197+
if device == "hpu" and not is_supported_on_hpu(quant_type):
198+
pytest.skip("This configuration is not supported on HPU.")
199+
191200
tensor = torch.randn(300, 400)
192201
param = bnb.nn.Params4bit(
193202
data=tensor,
@@ -207,6 +216,9 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
207216
@pytest.mark.parametrize("blocksize", [64, 128])
208217
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
209218
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
219+
if device == "hpu" and not is_supported_on_hpu(quant_type):
220+
pytest.skip("This configuration is not supported on HPU.")
221+
210222
tensor = torch.randn(300, 400)
211223
param = bnb.nn.Params4bit(
212224
data=tensor,
@@ -233,6 +245,9 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
233245
@pytest.mark.parametrize("blocksize", [64, 128])
234246
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
235247
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
248+
if device == "hpu" and not is_supported_on_hpu(quant_type):
249+
pytest.skip("This configuration is not supported on HPU.")
250+
236251
original_tensor = torch.randn(300, 400)
237252
original_param = bnb.nn.Params4bit(
238253
data=original_tensor,
@@ -270,6 +285,9 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
270285
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
271286
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
272287
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
288+
if device == "hpu" and not is_supported_on_hpu(quant_type):
289+
pytest.skip("This configuration is not supported on HPU.")
290+
273291
if fullgraph and torch.__version__ < (2, 8, 0, "dev"):
274292
pytest.skip("fullgraph mode requires torch 2.8 or higher")
275293

@@ -314,7 +332,8 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
314332
ref_output = net(x)
315333

316334
# Compile the model
317-
compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
335+
compile_backend = "hpu_backend" if device == "hpu" else "inductor"
336+
compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode, backend=compile_backend)
318337

319338
# Get output from compiled model
320339
with torch.no_grad():

tests/test_ops.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import bitsandbytes
77
from bitsandbytes.functional import ipex_xpu
8-
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter
8+
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu
99

1010
# torch.library.opcheck is only available in torch 2.4 and later.
1111
# When testing with older versions, we will skip it as a no-op.
@@ -158,6 +158,9 @@ class Test4bitBlockwiseQuantOps:
158158
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
159159
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
160160
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
161+
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
162+
pytest.skip("This configuration is not supported on HPU.")
163+
161164
A = torch.randn(1024, 1024, dtype=dtype, device=device)
162165

163166
out, absmax = torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, storage_dtype)
@@ -179,6 +182,9 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
179182
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
180183
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
181184
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
185+
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
186+
pytest.skip("This configuration is not supported on HPU.")
187+
182188
shape = (128, 128)
183189

184190
n = prod(shape)
@@ -210,6 +216,9 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
210216
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
211217
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
212218
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
219+
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
220+
pytest.skip("This configuration is not supported on HPU.")
221+
213222
out_features = 1024
214223
in_features = 256
215224

0 commit comments

Comments
 (0)