@@ -1,11 +1,17 @@
-import bitsandbytes as bnb
+import os
+from contextlib import nullcontext
+from itertools import product
+from tempfile import TemporaryDirectory
+
 import pytest
 import torch
-from bitsandbytes import functional as F
 
+import bitsandbytes as bnb
+from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
 
+
 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
 
@@ -26,6 +32,7 @@ def test_layout_exact_match():
     assert restored_x.is_contiguous()
     assert torch.all(torch.eq(restored_x, x))
 
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 def test_linear_no_igemmlt():
     linear = torch.nn.Linear(1024, 3072)
@@ -43,7 +50,7 @@ def test_linear_no_igemmlt():
         linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
     ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
-    linear = linear_custom.cuda()
+    linear_custom = linear_custom.cuda()
     linear = linear.half().cuda()
 
     x_ref = x.clone().cuda().requires_grad_(True)
@@ -59,3 +66,87 @@ def test_linear_no_igemmlt():
     assert not linear_custom.state.has_fp16_weights
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
+                         list(product([False, True], [False, True], [False, True], [False, True])))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
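+    # round-trip a Linear8bitLt through state_dict()/load_state_dict() and check that
+    # the restored layer reproduces the original forward outputs and input gradients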
+    linear = torch.nn.Linear(32, 96)
+    x = torch.randn(3, 32, dtype=torch.half)
+
+    linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=has_fp16_weights,
+        threshold=6.0,
+    )
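+    # force the fallback matmul path that avoids the igemmlt int8 kernel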
+    if force_no_igemmlt:
+        linear_custom.state.force_no_igemmlt = True
+
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
+    )
+    linear_custom.bias = linear.bias
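+    # .cuda() quantizes the weights to int8 unless has_fp16_weights is set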
+    linear_custom = linear_custom.cuda()
+
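+    # exercise serialization both before and after the first forward pass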
+    if serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
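+    # reference forward/backward; grad_proj is reused below so gradients are comparable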
+    x_first = x.clone().cuda().requires_grad_(True)
+    fx_first = linear_custom(x_first).float()
+    grad_proj = torch.randn_like(fx_first)
+    (fx_first * grad_proj).mean().backward()
+
+    if not serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
+    with TemporaryDirectory() as tmpdir:
+        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
+        state_path = os.path.join(tmpdir, "state.pth")
+
+        torch.save(linear.state_dict(), state_path)
+        torch.save(state_dict_8bit, state_path_8bit)
+
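+        # the int8 checkpoint must come in at under half the size of the fp16 one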
+        if not has_fp16_weights:
+            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
+
+        new_state_dict = torch.load(state_path_8bit)
+
+        new_linear_custom = Linear8bitLt(
+            linear.in_features,
+            linear.out_features,
+            linear.bias is not None,
+            has_fp16_weights=has_fp16_weights,
+            threshold=6.0,
+        )
+        if force_no_igemmlt:
+            new_linear_custom.state.force_no_igemmlt = True
+
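+        # loading int8 weights into a module that is still on CPU is unsupported,
+        # so the load should raise; with fp16 weights it succeeds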
+        if deserialize_before_cuda:
+            with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
+                new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
+        new_linear_custom = new_linear_custom.cuda()
+
+        if not deserialize_before_cuda:
+            new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
+        x_second = x.clone().cuda().requires_grad_(True)
+        fx_second = new_linear_custom(x_second).float()
+        (fx_second * grad_proj).mean().backward()
+
+        # if 8-bit weights were loaded before .cuda(), the state is invalid anyway
+        # and a RuntimeError was already raised above, so skip the comparison
+        if has_fp16_weights or not deserialize_before_cuda:
+            assert torch.allclose(fx_first, fx_second, atol=1e-5)
+            assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)