 from contextlib import nullcontext
+from io import BytesIO
 import os
 from tempfile import TemporaryDirectory
 
@@ -65,12 +66,25 @@ def test_linear_no_igemmlt():
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
 
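+# Round-trip helpers: torch.save an object into an in-memory BytesIO buffer
+# and torch.load it back, so whole-module serialization can be tested
+# without touching disk.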
+def torch_save_to_buffer(obj):
+    buffer = BytesIO()
+    torch.save(obj, buffer)
+    buffer.seek(0)
+    return buffer
+
+def torch_load_from_buffer(buffer):
+    buffer.seek(0)
+    obj = torch.load(buffer)
+    buffer.seek(0)
+    return obj
 
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
 @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
 @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
-def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
+@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
     linear = torch.nn.Linear(32, 96)
     x = torch.randn(3, 32, dtype=torch.half)
 
@@ -93,6 +107,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()
 
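+    # Optionally snapshot the whole module via torch.save before the forward
+    # pass, so both pre- and post-forward serialization are exercised.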
+    if save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     x_first = x.clone().cuda().requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
@@ -101,6 +118,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if not serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()
 
+    if not save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     with TemporaryDirectory() as tmpdir:
         state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
         state_path = os.path.join(tmpdir, "state.pth")
@@ -127,16 +147,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
         with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
             new_linear_custom.load_state_dict(new_state_dict, strict=True)
 
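+    # Load the torch.save'd copy either before or after moving the reference
+    # module to CUDA, mirroring the state_dict path above.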
+    if load_before_cuda:
+        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
     new_linear_custom = new_linear_custom.cuda()
 
     if not deserialize_before_cuda:
         new_linear_custom.load_state_dict(new_state_dict, strict=True)
 
+    if not load_before_cuda:
+        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
     x_second = x.clone().cuda().requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()
 
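+    # Repeat the forward/backward pass with the module restored from the
+    # in-memory buffer; results should match the original within tolerance.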
+    x_third = x.clone().cuda().requires_grad_(True)
+    fx_third = new_linear_custom2(x_third).float()
+    (fx_third * grad_proj).mean().backward()
+
     # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
     if has_fp16_weights or not deserialize_before_cuda:
         assert torch.allclose(fx_first, fx_second, atol=1e-5)
         assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
+        assert torch.allclose(fx_first, fx_third, atol=1e-5)
+        assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
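
For context, the new helpers let a whole Linear8bitLt module be round-tripped through torch.save/torch.load without touching disk. A minimal usage sketch, assuming bitsandbytes is installed and a CUDA device is available (the layer shape and names below are illustrative, not part of the diff):

    import torch
    import bitsandbytes as bnb
    from io import BytesIO

    # build an 8-bit linear layer and move it to the GPU (quantizes the weight)
    layer = bnb.nn.Linear8bitLt(32, 96, has_fp16_weights=False, threshold=6.0).cuda()

    # round-trip the whole module through an in-memory buffer
    buffer = BytesIO()
    torch.save(layer, buffer)
    buffer.seek(0)
    restored = torch.load(buffer)

    # the restored copy should produce the same outputs within tolerance
    x = torch.randn(3, 32, dtype=torch.half, device="cuda")
    assert torch.allclose(layer(x), restored(x), atol=1e-5)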