
Commit de6057b

Additional test patches for HPU
Parent: 1bc28db

3 files changed: +21 −5 lines

tests/test_linear4bit.py (0 additions, 3 deletions)

@@ -294,9 +294,6 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
     if device == "cuda" and platform.system() == "Windows":
         pytest.skip("Triton is not officially supported on Windows")
 
-    if device == "hpu" and quant_type != "nf4":
-        pytest.skip("fp4 dequantization is not supported on HPU")
-
     # Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
     if (
         not fullgraph
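
The skip removed here follows the parametrize-plus-conditional-skip pattern used throughout this test suite. As a minimal, self-contained sketch of that pattern (the device list, quant types, test name, and toy assertion below are illustrative, not taken from the bitsandbytes suite):

import pytest
import torch


@pytest.mark.parametrize("device", ["cpu", "cuda", "hpu"])
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
def test_roundtrip_sketch(device, quant_type):
    # Skip combinations the current machine or backend cannot run.
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if device == "hpu" and quant_type != "nf4":
        pytest.skip("this combination is not supported on HPU")

    # Placeholder body standing in for a real quantize/dequantize check.
    x = torch.randn(16, 16)
    torch.testing.assert_close(x, x.clone())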

tests/test_linear8bitlt.py (2 additions, 1 deletion)

@@ -257,7 +257,8 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
     ref_output = net(x)
 
     # Compile the model
-    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
+    compile_backend = "hpu_backend" if device == "hpu" else "inductor"
+    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode, backend=compile_backend)
 
     # Get output from compiled model
     with torch.no_grad():
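
To see the backend switch above in isolation, here is a minimal sketch. The toy model, input shape, and hard-coded device are placeholders (on a Gaudi machine the device string would be "hpu"); "inductor" is torch.compile's default backend, and "hpu_backend" is the name the patched test selects on HPU.

import torch
from torch import nn

device = "cpu"  # would come from the test parametrization; "hpu" on Gaudi hardware

net = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8)).to(device)
x = torch.randn(4, 32, device=device)

# Inductor is the default torch.compile backend; Gaudi provides its own
# "hpu_backend", which is what the patched test picks when device == "hpu".
compile_backend = "hpu_backend" if device == "hpu" else "inductor"
compiled_net = torch.compile(net, backend=compile_backend)

with torch.no_grad():
    ref_output = net(x)
    out = compiled_net(x)

# The compiled model should match eager execution numerically.
torch.testing.assert_close(out, ref_output)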

tests/test_modules.py (19 additions, 1 deletion)

@@ -5,7 +5,7 @@
 from torch import nn
 
 import bitsandbytes as bnb
-from tests.helpers import get_available_devices, id_formatter
+from tests.helpers import get_available_devices, id_formatter, is_supported_on_hpu
 
 
 class MockArgs:
@@ -295,7 +295,13 @@ def test_kbit_backprop(device, module):
     torch.nn.init.kaiming_normal_(ref[0].weight)
     torch.nn.init.kaiming_normal_(ref[1].weight)
     ref[1].weight.requires_grad_(False)
+
     kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
+
+    if device == "hpu":
+        if isinstance(module, bnb.nn.LinearFP4):
+            pytest.skip("FP4 is not supported on HPU")
+
     kbit[0].weight.detach().copy_(ref[0].weight)
     kbit[1].weight.detach().copy_(ref[1].weight)
     kbit[0].bias.detach().copy_(ref[0].bias)
@@ -358,6 +364,12 @@ def test_kbit_backprop(device, module):
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
 def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
+    if device == "hpu":
+        if embedding_class is bnb.nn.EmbeddingFP4:
+            pytest.skip("FP4 is not supported on HPU")
+        elif embedding_class is bnb.nn.EmbeddingNF4 and not is_supported_on_hpu("nf4", torch.float32, quant_storage):
+            pytest.skip("This configuration is not supported on HPU")
+
     num_embeddings = 128
 
     src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
@@ -403,6 +415,12 @@ def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim,
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
 def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
+    if device == "hpu":
+        if embedding_class is bnb.nn.EmbeddingFP4:
+            pytest.skip("FP4 is not supported on HPU")
+        elif embedding_class is bnb.nn.EmbeddingNF4 and not is_supported_on_hpu("nf4", torch.float32, quant_storage):
+            pytest.skip("This configuration is not supported on HPU")
+
     is_8bit = embedding_class is bnb.nn.Embedding8bit
 
     num_embeddings = 128
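
The new skips rely on is_supported_on_hpu from tests/helpers, whose body is not part of this diff. Judging only from its call site above, is_supported_on_hpu(quant_type, dtype, quant_storage), a gate of this shape might look roughly like the following; the specific rules here are an assumption for illustration, not the helper's actual implementation.

import torch


def is_supported_on_hpu(quant_type: str, dtype: torch.dtype, quant_storage: torch.dtype) -> bool:
    # Assumption: HPU kernels only cover NF4 quantization.
    if quant_type != "nf4":
        return False
    # Assumption: only a known-good set of compute dtypes is accepted.
    if dtype not in (torch.float32, torch.bfloat16):
        return False
    # Assumption: quantized weights are stored as uint8 or in the compute dtype.
    return quant_storage in (torch.uint8, dtype)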
