
Commit 2ffbf22

Test suite improvements for MPS/XPU/HPU
Parent: d731fc4

File tree

.github/workflows/tests.yml
tests/conftest.py
tests/test_functional.py
tests/test_optim.py

4 files changed: 25 additions, 15 deletions

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ jobs:
             pypi_index: "https://download.pytorch.org/whl/cu128"
           - cuda_version: "12.9.1"
             torch_version: "2.8.0"
-            pypi_index: "https://download.pytorch.org/whl/test/cu129"
+            pypi_index: "https://download.pytorch.org/whl/cu129"


       # Linux L40S runners

tests/conftest.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ def pytest_runtest_teardown(item, nextitem):
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        torch.mps.empty_cache()


 @pytest.fixture(scope="session")
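
Note: with this change the per-test teardown also releases cached MPS memory on Apple-silicon runners, mirroring the existing CUDA path. A minimal sketch of how the hook reads after the change, assuming torch is imported at the top of conftest.py:

    import gc

    import torch


    def pytest_runtest_teardown(item, nextitem):
        # Drop Python-level references first so the backend caches can actually shrink.
        gc.collect()
        if torch.cuda.is_available():
            # Return cached CUDA allocator blocks between tests.
            torch.cuda.empty_cache()
        elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
            # Same idea for MPS: free cached memory held by the MPS allocator.
            torch.mps.empty_cache()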

tests/test_functional.py

Lines changed: 11 additions & 14 deletions
@@ -3,7 +3,6 @@
 import time
 
 import einops
-import numpy as np
 import pytest
 import torch
 
@@ -101,16 +100,16 @@ class Test8BitBlockwiseQuantizeFunctional:
     def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
         iters = 100
 
-        if device == "cpu":
+        if device != "cuda":
             iters = 10
 
-            # This test is slow on CPU, so avoid atypical use cases.
+            # This test is slow in our non-CUDA implementations, so avoid atypical use cases.
             if nested:
                 pytest.skip("Not a typical use case.")
             if blocksize != 256:
-                pytest.skip("Only blocksize 256 is used in CPU/XPU")
+                pytest.skip("Only blocksize 256 is used in CPU/MPS/XPU")
             if dtype != torch.float32:
-                pytest.skip("Only float32 is used in CPU/XPU")
+                pytest.skip("Only float32 is used in CPU/MPS/XPU")
 
         diffs = []
         reldiffs = []
@@ -239,7 +238,7 @@ def test_fp8_quant(self, device):
 
         abserr = []
         relerr = []
-        for i in range(100):
+        for i in range(10):
             A1 = torch.randn(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1, code=code)
             A2 = F.dequantize_blockwise(C, SC)
@@ -253,7 +252,7 @@ def test_fp8_quant(self, device):
 
         abserr = []
         relerr = []
-        for i in range(100):
+        for i in range(10):
             A1 = torch.rand(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1, code=code)
             A2 = F.dequantize_blockwise(C, SC)
@@ -267,7 +266,7 @@ def test_fp8_quant(self, device):
 
         abserr = []
         relerr = []
-        for i in range(100):
+        for i in range(10):
             A1 = torch.randn(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1)
             A2 = F.dequantize_blockwise(C, SC)
@@ -1406,28 +1405,26 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-    @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
     @pytest.mark.skipif(
         HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
-        reason="this test is not supported on ROCm with gfx90a architecture yet",
+        reason="this test is not supported on ROCm with gfx90a architecture yet",
     )
-    def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
+    def test_gemv_eye_4bit(self, device, storage_type, dtype):
         if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
             pytest.skip("eye doe not support bfloat16 on CPU in torch < 2.3")
 
         if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
             pytest.skip("This configuration is not supported on HPU.")
 
-        dims = 10
-        torch.random.manual_seed(np.random.randint(0, 412424242))
+        dims = 4
         dims = get_test_dims(0, 8192, n=dims)
         dims = [dim + (64 - (dim % 64)) for dim in dims]
         # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
         for dim in dims:
             A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device)
             B = torch.eye(dim, dtype=dtype, device=device)
 
-            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=False)
             C3 = torch.matmul(A, B.t())
             C2 = bnb.matmul_4bit(A, qB.t(), state)
             A.requires_grad = True
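
Note: test_gemv_eye_4bit quantizes an identity matrix, so the 4-bit matmul should reproduce the input up to quantization error; with the double_quant parametrization removed, compress_statistics is now passed as False directly. A standalone sketch of the same check, assuming a CUDA build of bitsandbytes (the tolerances below are illustrative, not the test's actual thresholds):

    import bitsandbytes as bnb
    import bitsandbytes.functional as F
    import torch

    device, dtype, dim = "cuda", torch.float16, 1024

    A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device)
    B = torch.eye(dim, dtype=dtype, device=device)

    # Quantize the identity matrix to NF4 without double-quantizing the absmax statistics.
    qB, state = F.quantize_4bit(B, quant_type="nf4", compress_statistics=False)

    ref = torch.matmul(A, B.t())              # exact reference: A @ I == A
    out = bnb.matmul_4bit(A, qB.t(), state)   # 4-bit GEMV against the quantized identity

    torch.testing.assert_close(out, ref, atol=0.05, rtol=0.05)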

tests/test_optim.py

Lines changed: 11 additions & 0 deletions
@@ -172,6 +172,10 @@ def rm_path(path):
 @pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
 @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
 def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
+
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("Optimizers are only supported on CUDA and XPU")
+
     if optim_name.startswith("paged_") and sys.platform == "win32":
         pytest.skip("Paged optimizers can have issues on Windows.")
 
@@ -253,6 +257,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
 @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
 @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
 def test_global_config(dim1, dim2, gtype, device):
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("Optimizers are only supported on CUDA and XPU")
+
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
@@ -310,6 +317,10 @@ def test_global_config(dim1, dim2, gtype, device):
 @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
 @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
 def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
+
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("8-bit optimizers are only supported on CUDA and XPU")
+
     torch.set_printoptions(precision=6)
 
     if dim1 == 1 and dim2 == 1:
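
Note: all three optimizer tests now bail out early on any backend other than CUDA and XPU, so MPS/HPU runners collect them but skip instead of failing. A small sketch of the guard pattern in isolation; get_available_devices here is a simplified stand-in for the helper in the repo's test utilities:

    import pytest
    import torch


    def get_available_devices(no_cpu=False):
        # Simplified stand-in: report which accelerator device strings are usable here.
        devices = [] if no_cpu else ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            devices.append("xpu")
        return devices


    @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
    @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
    def test_optimizer_guard_pattern(device):
        # Same early-exit guard the commit adds: anything that is not CUDA/XPU skips.
        if device not in ["cuda", "xpu"]:
            pytest.skip("Optimizers are only supported on CUDA and XPU")
        assert device in ["cuda", "xpu"]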
