|
5 | 5 | from torch import nn |
6 | 6 |
|
7 | 7 | import bitsandbytes as bnb |
8 | | -from tests.helpers import get_available_devices, id_formatter |
| 8 | +from tests.helpers import get_available_devices, id_formatter, is_supported_on_hpu |
9 | 9 |
|
10 | 10 |
|
11 | 11 | class MockArgs: |
@@ -295,7 +295,13 @@ def test_kbit_backprop(device, module): |
295 | 295 | torch.nn.init.kaiming_normal_(ref[0].weight) |
296 | 296 | torch.nn.init.kaiming_normal_(ref[1].weight) |
297 | 297 | ref[1].weight.requires_grad_(False) |
| 298 | + |
298 | 299 | kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)]) |
| 300 | + |
| 301 | + if device == "hpu": |
| 302 | +        if module is bnb.nn.LinearFP4: |
| 303 | + pytest.skip("FP4 is not supported on HPU") |
| 304 | + |
299 | 305 | kbit[0].weight.detach().copy_(ref[0].weight) |
300 | 306 | kbit[1].weight.detach().copy_(ref[1].weight) |
301 | 307 | kbit[0].bias.detach().copy_(ref[0].bias) |
@@ -358,6 +364,12 @@ def test_kbit_backprop(device, module): |
358 | 364 | ids=lambda x: x.__name__ if inspect.isclass(x) else str(x), |
359 | 365 | ) |
360 | 366 | def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage): |
| 367 | + if device == "hpu": |
| 368 | + if embedding_class is bnb.nn.EmbeddingFP4: |
| 369 | + pytest.skip("FP4 is not supported on HPU") |
| 370 | + elif embedding_class is bnb.nn.EmbeddingNF4 and not is_supported_on_hpu("nf4", torch.float32, quant_storage): |
| 371 | + pytest.skip("This configuration is not supported on HPU") |
| 372 | + |
361 | 373 | num_embeddings = 128 |
362 | 374 |
|
363 | 375 | src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to( |
@@ -403,6 +415,12 @@ def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, |
403 | 415 | ids=lambda x: x.__name__ if inspect.isclass(x) else str(x), |
404 | 416 | ) |
405 | 417 | def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage): |
| 418 | + if device == "hpu": |
| 419 | + if embedding_class is bnb.nn.EmbeddingFP4: |
| 420 | + pytest.skip("FP4 is not supported on HPU") |
| 421 | + elif embedding_class is bnb.nn.EmbeddingNF4 and not is_supported_on_hpu("nf4", torch.float32, quant_storage): |
| 422 | + pytest.skip("This configuration is not supported on HPU") |
| 423 | + |
406 | 424 | is_8bit = embedding_class is bnb.nn.Embedding8bit |
407 | 425 |
|
408 | 426 | num_embeddings = 128 |
|
0 commit comments