Skip to content

Commit d02b536

Browse files
More test fixes
1 parent 0410ec1 commit d02b536

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

bitsandbytes/functional.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -779,7 +779,7 @@ def quantize_blockwise(
779779
state2=state2,
780780
)
781781
else:
782-
quant_state = QuantState(absmax=_absmax, code=code, blocksize=blocksize, dtype=A.dtype)
782+
quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
783783

784784
# TODO(matthewdouglas): Deprecate out kwarg
785785
out = out.copy_(_out) if out is not None else _out

tests/test_linear4bit.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,11 @@
2424
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
2525
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
2626
def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
27-
if device == "cpu" and quant_type == "fp4":
28-
pytest.xfail("FP4 is not supported for CPU")
27+
if device == "cpu":
28+
if quant_type == "fp4":
29+
pytest.xfail("FP4 is not supported for CPU")
30+
if quant_storage != "uint8":
31+
pytest.xfail("Only uint8 storage is supported for CPU")
2932

3033
original_dtype = torch.float16
3134
compute_dtype = None
@@ -144,8 +147,9 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua
144147
linear_q3 = torch_load_from_buffer(bytes_4bit)
145148

146149
# Test moving to CPU and back to GPU
147-
linear_q2.to("cpu")
148-
linear_q2.to(device)
150+
if device != "cpu":
151+
linear_q2.to("cpu")
152+
linear_q2.to(device)
149153
d = linear_qs(x)
150154
assert c.dtype == d.dtype
151155
assert c.device == d.device

0 commit comments

Comments (0)