from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, patch_fsdp_mp_dtypes


-@pytest.fixture(autouse=True)
-def patch_fsdp_dtypes():
-    """Automatically patch FSDP mixed precision dtypes for all tests in this module."""
-    with patch_fsdp_mp_dtypes():
-        yield
-
-
def _update_weight_test(rank, size):
4134 """Test fsdp2 weight update context for weight update -> only value changed"""
    from torch.distributed._composable.fsdp import fully_shard

-    # Define and shard model
-    model = ToyModel(dims=[4, 4], bias=False).to("cuda")
+    with patch_fsdp_mp_dtypes():
+        # Define and shard model
+        model = ToyModel(dims=[4, 4], bias=False).to("cuda")

-    assert not torch.equal(
-        model.linears.weight.data,
-        torch.zeros(4, 4).to(model.linears.weight.device).to(model.linears.weight.dtype),
-    )
+        assert not torch.equal(
+            model.linears.weight.data,
+            torch.zeros(4, 4).to(model.linears.weight.device).to(model.linears.weight.dtype),
+        )

-    fully_shard(model.linears)
-    fully_shard(model)
+        fully_shard(model.linears)
+        fully_shard(model)

-    torch.distributed.barrier()
+        torch.distributed.barrier()

-    for name, module in model.named_modules():
-        if "linears" in name:
-            with fsdp2_aware_weight_update(model, module):
-                module.weight.data = torch.zeros_like(module.weight.data)
+        for name, module in model.named_modules():
+            if "linears" in name:
+                with fsdp2_aware_weight_update(model, module):
+                    module.weight.data = torch.zeros_like(module.weight.data)

-    torch.distributed.barrier()
-    model.linears.unshard()
+        torch.distributed.barrier()
+        model.linears.unshard()

-    # Check if weights are as expected after unshard
-    for param in model.parameters():
-        assert torch.allclose(
-            torch.zeros(4, 4).to(param.data.device).to(param.data.dtype), param.data
-        )
+        # Check if weights are as expected after unshard
+        for param in model.parameters():
+            assert torch.allclose(
+                torch.zeros(4, 4).to(param.data.device).to(param.data.dtype), param.data
+            )

-    # Check if forward pass is as expected
-    model.linears.reshard()
-    output = model(torch.randn(4, 4).to(model.linears.weight.device))
-    assert torch.allclose(torch.zeros(4, 4).to(output.device).to(output.dtype), output)
+        # Check if forward pass is as expected
+        model.linears.reshard()
+        output = model(torch.randn(4, 4).to(model.linears.weight.device))
+        assert torch.allclose(torch.zeros(4, 4).to(output.device).to(output.dtype), output)


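Each `_*_test(rank, size)` worker assumes an already-initialized process group and one GPU per rank; the launcher itself sits outside this hunk. A minimal sketch of how such a worker could be driven, assuming 2 GPUs and the NCCL backend (the `_run` wrapper and the port choice are illustrative, not part of this test module):

import os

import torch
import torch.multiprocessing as mp

def _run(rank, size):
    # Hypothetical launcher: one process per GPU, NCCL backend.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.distributed.init_process_group("nccl", rank=rank, world_size=size)
    torch.cuda.set_device(rank)
    try:
        _update_weight_test(rank, size)
    finally:
        torch.distributed.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_run, args=(2,), nprocs=2, join=True)
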
def _compress_weight_test(rank, size):
7872 """Test fsdp2 weight update context for weight compression -> only value,shape and dtype changed"""
    from torch.distributed._composable.fsdp import fully_shard

-    # Define and shard model
-    model = ToyModel(dims=[6, 6], bias=False).to("cuda")
+    with patch_fsdp_mp_dtypes():
+        # Define and shard model
+        model = ToyModel(dims=[6, 6], bias=False).to("cuda")

-    assert not torch.equal(
-        model.linears.weight.data,
-        torch.zeros(6, 6).to(model.linears.weight.device).to(model.linears.weight.dtype),
-    )
+        assert not torch.equal(
+            model.linears.weight.data,
+            torch.zeros(6, 6).to(model.linears.weight.device).to(model.linears.weight.dtype),
+        )

-    fully_shard(model.linears)
-    fully_shard(model)
-    torch.distributed.barrier()
+        fully_shard(model.linears)
+        fully_shard(model)
+        torch.distributed.barrier()

-    for name, module in model.named_modules():
-        if "linears" in name:
-            with fsdp2_aware_weight_update(model, module):
-                module.weight.data = (
-                    torch.zeros(2, 2).to(torch.float8_e4m3fn).to(module.weight.data.device)
-                )
+        for name, module in model.named_modules():
+            if "linears" in name:
+                with fsdp2_aware_weight_update(model, module):
+                    module.weight.data = (
+                        torch.zeros(2, 2).to(torch.float8_e4m3fn).to(module.weight.data.device)
+                    )

-    torch.distributed.barrier()
-    model.linears.unshard()
-    # Check if weights are as expected after unshard
-    for param in model.parameters():
-        assert param.data.dtype == torch.float8_e4m3fn
+        torch.distributed.barrier()
+        model.linears.unshard()
+        # Check if weights are as expected after unshard
+        for param in model.parameters():
+            assert param.data.dtype == torch.float8_e4m3fn


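Compression is the stricter exercise of `fsdp2_aware_weight_update`: the replacement tensor changes shape (6x6 to 2x2) and dtype (to `torch.float8_e4m3fn`), not just values, which is exactly what the docstring flags. A standalone sanity sketch of the dtype behavior the assertion above relies on, assuming a PyTorch build (>= 2.1) that ships `float8_e4m3fn`:

import torch

# float8_e4m3fn stores one byte per element; value checks go through a
# wider dtype because most float8 ops are not implemented directly.
w = torch.zeros(2, 2, dtype=torch.float8_e4m3fn)
assert w.element_size() == 1
assert torch.equal(w.to(torch.float32), torch.zeros(2, 2))
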
def _compare_parameters_and_buffers(model1, model2):
@@ -126,97 +121,99 @@ def _fuse_layers(rank, size, quant_config):

    from torch.distributed._composable.fsdp import fully_shard

-    # Initialize model
-    model = SmallQKVModel(dim=32).to("cuda")
-    non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
-    non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
-    model.eval()
-    non_fsdp_model.eval()
+    with patch_fsdp_mp_dtypes():
+        # Initialize model
+        model = SmallQKVModel(dim=32).to("cuda")
+        non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
+        non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
+        model.eval()
+        non_fsdp_model.eval()

-    _compare_parameters_and_buffers(model, non_fsdp_model)
+        _compare_parameters_and_buffers(model, non_fsdp_model)

-    # Create calibration data ONCE
-    calib_data = torch.randn(1, 32, device="cuda")
+        # Create calibration data ONCE
+        calib_data = torch.randn(1, 32, device="cuda")

-    def calib_fn(x):
-        return x(calib_data)
+        def calib_fn(x):
+            return x(calib_data)

-    # Shard model
-    fully_shard(model)
-    torch.distributed.barrier()
+        # Shard model
+        fully_shard(model)
+        torch.distributed.barrier()

-    # Quantize model
-    mtq.quantize(model, quant_config, calib_fn)
-    mtq.quantize(non_fsdp_model, quant_config, calib_fn)
+        # Quantize model
+        mtq.quantize(model, quant_config, calib_fn)
+        mtq.quantize(non_fsdp_model, quant_config, calib_fn)

-    torch.distributed.barrier()
+        torch.distributed.barrier()

-    model.apply_embed = True
-    non_fsdp_model.apply_embed = True
+        model.apply_embed = True
+        non_fsdp_model.apply_embed = True

-    requantize_resmooth_fused_llm_layers(model)
-    requantize_resmooth_fused_llm_layers(non_fsdp_model)
+        requantize_resmooth_fused_llm_layers(model)
+        requantize_resmooth_fused_llm_layers(non_fsdp_model)

-    torch.distributed.barrier()
+        torch.distributed.barrier()

-    # Unshard model
-    model.unshard()
+        # Unshard model
+        model.unshard()

-    _compare_parameters_and_buffers(model, non_fsdp_model)
+        _compare_parameters_and_buffers(model, non_fsdp_model)


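The body of `_compare_parameters_and_buffers` falls outside this hunk. For orientation, a hypothetical sketch of what such a name-aligned comparison could look like (not the file's actual implementation):

def _compare_parameters_and_buffers_sketch(model1, model2):
    # Compare by name so a failure reports which tensor diverged.
    params1, params2 = dict(model1.named_parameters()), dict(model2.named_parameters())
    assert params1.keys() == params2.keys()
    for name, p1 in params1.items():
        assert torch.allclose(p1.detach(), params2[name].detach()), name
    bufs1, bufs2 = dict(model1.named_buffers()), dict(model2.named_buffers())
    assert bufs1.keys() == bufs2.keys()
    for name, b1 in bufs1.items():
        assert torch.equal(b1, bufs2[name]), name
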
def _export_quantized_weight_test(rank, size, quant_config):
    import copy

    from torch.distributed._composable.fsdp import fully_shard

-    # Initialize model
-    model = SmallQKVModel(dim=32).to("cuda")
-    non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
-    non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
-    model.eval()
-    non_fsdp_model.eval()
-    _compare_parameters_and_buffers(model, non_fsdp_model)
+    with patch_fsdp_mp_dtypes():
+        # Initialize model
+        model = SmallQKVModel(dim=32).to("cuda")
+        non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
+        non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
+        model.eval()
+        non_fsdp_model.eval()
+        _compare_parameters_and_buffers(model, non_fsdp_model)

-    # Create calibration data ONCE
-    calib_data = torch.randn(1, 32, device="cuda")
+        # Create calibration data ONCE
+        calib_data = torch.randn(1, 32, device="cuda")

-    def calib_fn(x):
-        return x(calib_data)
+        def calib_fn(x):
+            return x(calib_data)

-    # Shard model
-    fully_shard(model)
-    torch.distributed.barrier()
+        # Shard model
+        fully_shard(model)
+        torch.distributed.barrier()

-    # Quantize model
-    mtq.quantize(model, quant_config, calib_fn)
-    mtq.quantize(non_fsdp_model, quant_config, calib_fn)
+        # Quantize model
+        mtq.quantize(model, quant_config, calib_fn)
+        mtq.quantize(non_fsdp_model, quant_config, calib_fn)

-    torch.distributed.barrier()
+        torch.distributed.barrier()

-    model.apply_embed = True
-    non_fsdp_model.apply_embed = True
+        model.apply_embed = True
+        non_fsdp_model.apply_embed = True

-    requantize_resmooth_fused_llm_layers(model)
-    requantize_resmooth_fused_llm_layers(non_fsdp_model)
+        requantize_resmooth_fused_llm_layers(model)
+        requantize_resmooth_fused_llm_layers(non_fsdp_model)

-    torch.distributed.barrier()
+        torch.distributed.barrier()

-    for name, sub_module in model.named_modules():
-        if is_quantlinear(sub_module):
-            with fsdp2_aware_weight_update(model, sub_module):
-                _export_quantized_weight(sub_module, torch.float16)
+        for name, sub_module in model.named_modules():
+            if is_quantlinear(sub_module):
+                with fsdp2_aware_weight_update(model, sub_module):
+                    _export_quantized_weight(sub_module, torch.float16)

-    for name, sub_module in non_fsdp_model.named_modules():
-        if is_quantlinear(sub_module):
-            with fsdp2_aware_weight_update(non_fsdp_model, sub_module):
-                _export_quantized_weight(sub_module, torch.float16)
+        for name, sub_module in non_fsdp_model.named_modules():
+            if is_quantlinear(sub_module):
+                with fsdp2_aware_weight_update(non_fsdp_model, sub_module):
+                    _export_quantized_weight(sub_module, torch.float16)

-    torch.distributed.barrier()
-    # Unshard model
-    model.unshard()
+        torch.distributed.barrier()
+        # Unshard model
+        model.unshard()

-    _compare_parameters_and_buffers(model, non_fsdp_model)
+        _compare_parameters_and_buffers(model, non_fsdp_model)


@pytest.mark.parametrize("device_count", [2])
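The parametrized test under this decorator is truncated in this view. A plausible shape for such an entry point, reusing the hypothetical `_run` launcher sketched earlier and skipping on machines with too few GPUs (illustrative only):

def test_fsdp2_weight_update(device_count):
    # Hypothetical body for the truncated parametrized test above.
    if torch.cuda.device_count() < device_count:
        pytest.skip(f"requires {device_count} CUDA devices")
    mp.spawn(_run, args=(device_count,), nprocs=device_count, join=True)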