Commit 048a2d4

Deduplicate helpers & fix lint issues from #1099 (#1107)

1 parent a1c0844 commit 048a2d4

3 files changed (+28, −34 lines)

tests/helpers.py
Lines changed: 20 additions & 7 deletions

@@ -1,3 +1,4 @@
+from io import BytesIO
 from itertools import product
 import random
 from typing import Any, List
@@ -7,6 +8,25 @@
 test_dims_rng = random.Random(42)
 
 
+TRUE_FALSE = (True, False)
+BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3))  # all combinations of (bool, bool, bool)
+BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)
+
+
+def torch_save_to_buffer(obj):
+    buffer = BytesIO()
+    torch.save(obj, buffer)
+    buffer.seek(0)
+    return buffer
+
+
+def torch_load_from_buffer(buffer):
+    buffer.seek(0)
+    obj = torch.load(buffer)
+    buffer.seek(0)
+    return obj
+
+
 def get_test_dims(min: int, max: int, *, n: int) -> List[int]:
     return [test_dims_rng.randint(min, max) for _ in range(n)]
 
@@ -42,10 +62,3 @@ def id_formatter(label: str):
 
 def describe_dtype(dtype: torch.dtype) -> str:
     return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]
-
-
-TRUE_FALSE = (True, False)
-BOOLEAN_TRIPLES = list(
-    product(TRUE_FALSE, repeat=3)
-)  # all combinations of (bool, bool, bool)
-BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)
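For context, a minimal sketch (not part of the commit) of how the now-shared buffer helpers and boolean-combination constants are used; the import path assumes the repository's tests/ package layout:

import torch

from tests.helpers import BOOLEAN_TRIPLES, torch_load_from_buffer, torch_save_to_buffer

# Round-trip an object through an in-memory buffer instead of a temp file.
tensor = torch.randn(4, 4)
buffer = torch_save_to_buffer(tensor)      # returns a BytesIO rewound to the start
restored = torch_load_from_buffer(buffer)  # rewinds before and after loading
assert torch.equal(tensor, restored)

# BOOLEAN_TRIPLES enumerates all 2**3 = 8 (bool, bool, bool) combinations,
# ready to feed into @pytest.mark.parametrize for three independent flags.
assert len(BOOLEAN_TRIPLES) == 8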

tests/test_linear4bit.py
Lines changed: 1 addition & 13 deletions

@@ -1,5 +1,4 @@
 import copy
-from io import BytesIO
 import os
 import pickle
 from tempfile import TemporaryDirectory
@@ -8,7 +7,7 @@
 import torch
 
 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE
+from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
 
 storage = {
     "uint8": torch.uint8,
@@ -17,17 +16,6 @@
     "float32": torch.float32,
 }
 
-def torch_save_to_buffer(obj):
-    buffer = BytesIO()
-    torch.save(obj, buffer)
-    buffer.seek(0)
-    return buffer
-
-def torch_load_from_buffer(buffer):
-    buffer.seek(0)
-    obj = torch.load(buffer)
-    buffer.seek(0)
-    return obj
 
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)

tests/test_linear8bitlt.py
Lines changed: 7 additions & 14 deletions

@@ -1,5 +1,4 @@
 from contextlib import nullcontext
-from io import BytesIO
 import os
 from tempfile import TemporaryDirectory
 
@@ -10,7 +9,12 @@
 from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
-from tests.helpers import TRUE_FALSE, id_formatter
+from tests.helpers import (
+    TRUE_FALSE,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)
 
 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@@ -66,17 +70,6 @@ def test_linear_no_igemmlt():
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
 
-def torch_save_to_buffer(obj):
-    buffer = BytesIO()
-    torch.save(obj, buffer)
-    buffer.seek(0)
-    return buffer
-
-def torch_load_from_buffer(buffer):
-    buffer.seek(0)
-    obj = torch.load(buffer)
-    buffer.seek(0)
-    return obj
 
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
@@ -171,4 +164,4 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     assert torch.allclose(fx_first, fx_second, atol=1e-5)
     assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
     assert torch.allclose(fx_first, fx_third, atol=1e-5)
-    assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
\ No newline at end of file
+    assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
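For orientation, a stripped-down sketch of the serialize-before/after-forward pattern that test_linear_serialization parametrizes. A plain nn.Linear stands in for Linear8bitLt here, and a torch version where torch.load defaults to weights_only=False is assumed, matching the helpers as written:

import torch

from tests.helpers import torch_load_from_buffer, torch_save_to_buffer

linear = torch.nn.Linear(8, 8)           # stand-in for Linear8bitLt
x = torch.randn(2, 8, requires_grad=True)

buffer = torch_save_to_buffer(linear)    # serialize before the forward pass
fx_first = linear(x)

linear_second = torch_load_from_buffer(buffer)
fx_second = linear_second(x)             # the restored copy must produce the same output

assert torch.allclose(fx_first, fx_second, atol=1e-5)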
