@@ -1,5 +1,4 @@
 from contextlib import nullcontext
-from io import BytesIO
 import os
 from tempfile import TemporaryDirectory
 
@@ -10,7 +9,12 @@
 from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
-from tests.helpers import TRUE_FALSE, id_formatter
+from tests.helpers import (
+    TRUE_FALSE,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)
 
 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@@ -66,17 +70,6 @@ def test_linear_no_igemmlt():
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
 
-def torch_save_to_buffer(obj):
-    buffer = BytesIO()
-    torch.save(obj, buffer)
-    buffer.seek(0)
-    return buffer
-
-def torch_load_from_buffer(buffer):
-    buffer.seek(0)
-    obj = torch.load(buffer)
-    buffer.seek(0)
-    return obj
 
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
@@ -171,4 +164,4 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     assert torch.allclose(fx_first, fx_second, atol=1e-5)
     assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
     assert torch.allclose(fx_first, fx_third, atol=1e-5)
-    assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
+    assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
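
For context, the torch_save_to_buffer and torch_load_from_buffer helpers removed above are now imported from tests.helpers; presumably their definitions there mirror the removed local copies, roughly:

# tests/helpers.py (sketch; assumed to match the helpers removed from this test file)
from io import BytesIO

import torch

def torch_save_to_buffer(obj):
    # Serialize obj into an in-memory buffer and rewind it so it is ready to read.
    buffer = BytesIO()
    torch.save(obj, buffer)
    buffer.seek(0)
    return buffer

def torch_load_from_buffer(buffer):
    # Rewind, deserialize, then rewind again so the buffer can be reused.
    buffer.seek(0)
    obj = torch.load(buffer)
    buffer.seek(0)
    return obj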