@@ -118,11 +118,11 @@ def _compare_parameters_and_buffers(model1, model2):
118118 )
119119
120120
121- def _fuse_layers (rank , size , quant_config ):
121+ def _fuse_layers (rank , size , quant_config , bias ):
122122 with patch_fsdp_mp_dtypes ():
123123 # Initialize model
124- model = SmallQKVModel (dim = 32 ).to ("cuda" )
125- non_fsdp_model = SmallQKVModel (dim = 32 ).to ("cuda" )
124+ model = SmallQKVModel (dim = 32 , bias = bias ).to ("cuda" )
125+ non_fsdp_model = SmallQKVModel (dim = 32 , bias = bias ).to ("cuda" )
126126 non_fsdp_model .load_state_dict (copy .deepcopy (model .state_dict ()))
127127 model .eval ()
128128 non_fsdp_model .eval ()
@@ -159,15 +159,15 @@ def calib_fn(x):
159159 _compare_parameters_and_buffers (model , non_fsdp_model )
160160
161161
162- def _export_quantized_weight_test (rank , size , quant_config ):
162+ def _export_quantized_weight_test (rank , size , quant_config , bias ):
163163 import copy
164164
165165 from torch .distributed ._composable .fsdp import fully_shard
166166
167167 with patch_fsdp_mp_dtypes ():
168168 # Initialize model
169- model = SmallQKVModel (dim = 32 ).to ("cuda" )
170- non_fsdp_model = SmallQKVModel (dim = 32 ).to ("cuda" )
169+ model = SmallQKVModel (dim = 32 , bias = bias ).to ("cuda" )
170+ non_fsdp_model = SmallQKVModel (dim = 32 , bias = bias ).to ("cuda" )
171171 non_fsdp_model .load_state_dict (copy .deepcopy (model .state_dict ()))
172172 model .eval ()
173173 non_fsdp_model .eval ()
@@ -247,10 +247,11 @@ def test_fsdp2_weight_update_context_for_export(device_count):
247247 ],
248248)
249249@pytest .mark .parametrize ("device_count" , get_device_counts ())
250- def test_fsdp2_weight_update_context_for_fuse_layers (device_count , quant_config ):
250+ @pytest .mark .parametrize ("bias" , [True , False ])
251+ def test_fsdp2_weight_update_context_for_fuse_layers (device_count , quant_config , bias ):
251252 spawn_multiprocess_job (
252253 size = device_count ,
253- job = partial (_fuse_layers , quant_config = quant_config ),
254+ job = partial (_fuse_layers , quant_config = quant_config , bias = bias ),
254255 backend = "nccl" ,
255256 )
256257
@@ -270,9 +271,10 @@ def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config)
270271 ],
271272)
272273@pytest .mark .parametrize ("device_count" , get_device_counts ())
273- def test_fsdp2_weight_update_context_for_export_quantized_weight (device_count , quant_config ):
274+ @pytest .mark .parametrize ("bias" , [True , False ])
275+ def test_fsdp2_weight_update_context_for_export_quantized_weight (device_count , quant_config , bias ):
274276 spawn_multiprocess_job (
275277 size = device_count ,
276- job = partial (_export_quantized_weight_test , quant_config = quant_config ),
278+ job = partial (_export_quantized_weight_test , quant_config = quant_config , bias = bias ),
277279 backend = "nccl" ,
278280 )
0 commit comments