Skip to content

Commit f1569d5

Browse files
Improvements for testing suite
1 parent 4fb52dc commit f1569d5

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

tests/test_functional.py

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,11 @@ class Test8BitBlockwiseQuantizeFunctional:
9494
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
9595
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
9696
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
97+
iters = 100
98+
9799
if device == "cpu":
100+
iters = 10
101+
98102
# This test is slow on CPU, so avoid atypical use cases.
99103
if nested:
100104
pytest.skip("Not a typical use case.")
@@ -106,7 +110,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
106110

107111
diffs = []
108112
reldiffs = []
109-
for i in range(100):
113+
for i in range(iters):
110114
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
111115
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
112116
A2 = F.dequantize_blockwise(C, S)
@@ -116,15 +120,13 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
116120
reldiffs.append(reldiff.mean().item())
117121
abserr = sum(diffs) / len(diffs)
118122
relerr = sum(reldiffs) / len(reldiffs)
119-
# print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
120-
# print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
121123
assert abserr < 0.011
122124
assert relerr < 0.018
123125
assert A2.dtype == dtype
124126

125127
diffs = []
126128
code = F.create_dynamic_map(signed=signed)
127-
for i in range(100):
129+
for i in range(iters):
128130
A1 = torch.rand(1024, 1024, device=device, dtype=dtype)
129131
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
130132
A2 = F.dequantize_blockwise(C, S)
@@ -142,29 +144,29 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
142144
assert abserr < 0.00175
143145
assert relerr < 0.012
144146
assert A2.dtype == dtype
145-
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
146-
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
147147

148-
def test_blockwise_cpu_large(self):
148+
@pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required")
149+
@pytest.mark.parametrize("hidden", [128])
150+
@pytest.mark.parametrize("blocksize", [4096, 16384])
151+
def test_blockwise_cpu_large(self, hidden, blocksize):
149152
diffs = []
150153
reldiffs = []
151154
batch = 128
152155
seq = 128
153-
for hidden in [128]: # , 14336]:
154-
for blocksize in [4096, 16384]:
155-
for i in range(2):
156-
A1 = torch.randn(batch, seq, hidden, device="cpu")
157-
t0 = time.time()
158-
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
159-
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
160-
print(time.time() - t0)
161-
diff = torch.abs(A1 - A2)
162-
reldiff = diff / torch.abs(A1 + 1e-8)
163-
diffs.append(diff.mean().item())
164-
reldiffs.append(reldiff.mean().item())
165-
assert diffs[-1] < 0.011
166-
# print(sum(diffs)/len(diffs))
167-
# print(sum(reldiffs)/len(reldiffs))
156+
157+
for i in range(2):
158+
A1 = torch.randn(batch, seq, hidden, device="cpu")
159+
t0 = time.time()
160+
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
161+
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
162+
print(time.time() - t0)
163+
diff = torch.abs(A1 - A2)
164+
reldiff = diff / torch.abs(A1 + 1e-8)
165+
diffs.append(diff.mean().item())
166+
reldiffs.append(reldiff.mean().item())
167+
assert diffs[-1] < 0.011
168+
# print(sum(diffs)/len(diffs))
169+
# print(sum(reldiffs)/len(reldiffs))
168170

169171
@pytest.mark.parametrize("device", get_available_devices())
170172
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))

tests/test_ops.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -97,8 +97,12 @@ class TestInt8BlockwiseQuantOps:
9797
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
9898
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
9999
def test_quantize_blockwise(self, device, dtype, blocksize):
100-
if device == "cpu" and dtype != torch.float32:
101-
pytest.skip("CPU implementation is only available for float32")
100+
if device == "cpu":
101+
if dtype != torch.float32:
102+
pytest.skip("CPU implementation is only available for float32")
103+
104+
if blocksize != 256:
105+
pytest.skip("CPU implementation is slow; only test blocksize=256")
102106

103107
code = bitsandbytes.functional.create_dynamic_map().to(device)
104108
A = torch.randn(1024, 1024, dtype=dtype, device=device)

0 commit comments

Comments
 (0)