Commit 5c736a7

skip some fp4 tests on hpu
1 parent 5dae4a8 commit 5c736a7

3 files changed: +12 -0 lines changed

tests/test_autograd.py

Lines changed: 3 additions & 0 deletions
@@ -189,6 +189,9 @@ def test_matmul_4bit(
     if device == "cpu" and dtype != torch.float32 and any(req_grad) and torch.__version__ < (2, 6):
         pytest.xfail("mse_loss fp16 on CPU is not supported in torch < 2.6")

+    if device == "hpu" and quant_type != "nf4":
+        pytest.skip("HPU only supports nf4")
+
     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:

tests/test_linear4bit.py

Lines changed: 3 additions & 0 deletions
@@ -276,6 +276,9 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
     if device == "cuda" and platform.system() == "Windows":
         pytest.skip("Triton is not officially supported on Windows")

+    if device == "hpu" and quant_type != "nf4":
+        pytest.skip("fp4 dequantization is not supported on HPU")
+
     # Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
     if (
         not fullgraph

tests/test_ops.py

Lines changed: 6 additions & 0 deletions
@@ -179,6 +179,9 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
+        if device == "hpu" and quant_type != "nf4":
+            pytest.skip("fp4 dequantization is not supported on HPU")
+
         shape = (128, 128)

         n = prod(shape)
@@ -210,6 +213,9 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
+        if device == "hpu" and quant_type != "nf4":
+            pytest.skip("fp4 dequantization is not supported on HPU")
+
         out_features = 1024
         in_features = 256

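All three files apply the same guard: any parametrized case that pairs the hpu device with a quant type other than nf4 is skipped rather than left to fail. Below is a minimal, self-contained sketch of that pattern; the test name and parametrization are illustrative only and not part of this repository's suite, which derives device and the other parameters from its own fixtures.

import pytest

# Illustrative parametrization; the real tests cover more dtypes,
# storage dtypes, and blocksizes.
@pytest.mark.parametrize("device", ["cpu", "cuda", "hpu"])
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_hpu_skip_pattern(device, quant_type):
    # Mirrors the guard added in this commit: HPU currently supports only
    # nf4, so fp4 cases are skipped instead of reported as failures.
    if device == "hpu" and quant_type != "nf4":
        pytest.skip("fp4 dequantization is not supported on HPU")

    # ... a real test would quantize/dequantize a tensor here ...

Skipped cases show up as "s" in pytest's progress output, so the fp4 gaps on HPU stay visible without turning the run red.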