Skip to content

Commit bdd28f2

Browse files
skip some fp4 tests on hpu
1 parent 5c736a7 commit bdd28f2

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

tests/test_functional.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,9 @@ class TestQuantize4BitFunctional:
11011101
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
11021102
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
11031103
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
1104+
if device == "hpu" and quant_type != "nf4":
1105+
pytest.skip("fp4 dequantization is not supported on HPU")
1106+
11041107
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
11051108
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
11061109
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
@@ -1133,6 +1136,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11331136
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
11341137
@pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
11351138
def test_4bit_compressed_stats(self, device, quant_type, blocksize):
1139+
if device == "hpu" and quant_type != "nf4":
1140+
pytest.skip("fp4 dequantization is not supported on HPU")
1141+
11361142
errs1 = []
11371143
errs2 = []
11381144
for i in range(10):
@@ -1205,6 +1211,9 @@ def test_bench_4bit_dequant(self, quant_type):
12051211
)
12061212
@pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
12071213
def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
1214+
if device == "hpu" and storage_type != "nf4":
1215+
pytest.skip("fp4 dequantization is not supported on HPU")
1216+
12081217
errs1 = []
12091218
errs2 = []
12101219
errs3 = []
@@ -1354,6 +1363,9 @@ def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
13541363
if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
13551364
pytest.skip("eye doe not support bfloat16 on CPU in torch < 2.3")
13561365

1366+
if device == "hpu" and storage_type != "nf4":
1367+
pytest.skip("fp4 dequantization is not supported on HPU")
1368+
13571369
dims = 10
13581370
torch.random.manual_seed(np.random.randint(0, 412424242))
13591371
dims = get_test_dims(0, 8192, n=dims)

0 commit comments

Comments
 (0)