Skip to content

Commit 5497111

Browse files
committed
fix test_gemv
Signed-off-by: jiqing-feng <[email protected]>
1 parent c6b714d commit 5497111

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/test_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,26 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
219219
out_features = 1024
220220
in_features = 256
221221

222+
if device == "cpu" and blocksize > in_features:
223+
pytest.skip("CPU implementation only suppoer blocksize <= in_features")
224+
222225
A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
223226
B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
224227
B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype)
225228
code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize)
226229

230+
if device == "cpu" and bitsandbytes.functional.has_avx512bf16():
231+
state = bitsandbytes.functional.QuantState(
232+
absmax=absmax,
233+
shape=B.shape,
234+
dtype=A.dtype,
235+
blocksize=blocksize,
236+
code=code,
237+
quant_type=quant_type,
238+
)
239+
B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state)
240+
B_q = B_q.t()
241+
absmax = state.absmax
227242
out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)
228243

229244
assert out.device == A.device

0 commit comments

Comments
 (0)