@@ -274,76 +274,88 @@ def fsdp2_aware_weight_update(root_model, modules_to_update):
 
         from modelopt.torch.quantization.utils import _get_enclosing_fsdp_module, _get_module_name
 
-        breakpoint()
-        # Get FSDP root module, if none is returned, then the update is not made to a submodule of an FSDPModule
-        if not isinstance(modules_to_update, list):
-            modules_to_update = [modules_to_update]
-
-        root_modules = set()
-        for module in modules_to_update:
-            root_module = _get_enclosing_fsdp_module(module, root_model)
-            root_modules.add(root_module)
-
-        # Ensure all modules in root_modules are the same
-        assert len(root_modules) == 1, "All modules must be in the same root FSDPModule"
-        root_module = next(iter(root_modules))
-
-        # Check if root module state is sharded and unshard if needed
-        if fully_shard.state(root_module)._fsdp_param_group.is_sharded:
-            with enable_fake_quant(root_module):
-                root_module.unshard()
-
-        # Get FSDPParam list
-        fsdp_param_group = fully_shard.state(root_module)._fsdp_param_group
-        fsdp_param_mapping = _create_fsdp_param_mapping(fsdp_param_group.fsdp_params, root_module)
-
-        # Assert that all the modules in the module list are present in this fsdp_param_group
-        for module in modules_to_update:
-            name = _get_module_name(module, root_module)
-            assert name in fsdp_param_mapping, f"Module {module} not found in fsdp_param_mapping"
+        if isinstance(root_model, FSDPModule):
+            # Get FSDP root module, if none is returned, then the update is not made to a submodule of an FSDPModule
+            if not isinstance(modules_to_update, list):
+                modules_to_update = [modules_to_update]
+
+            root_modules = set()
+            for module in modules_to_update:
+                root_module = _get_enclosing_fsdp_module(module, root_model)
+                root_modules.add(root_module)
+
+            # Ensure all modules in root_modules are the same
+            assert len(root_modules) == 1, "All modules must be in the same root FSDPModule"
+            root_module = next(iter(root_modules))
+
+            # Check if root module state is sharded and unshard if needed
+            if fully_shard.state(root_module)._fsdp_param_group.is_sharded:
+                with enable_fake_quant(root_module):
+                    root_module.unshard()
+
+            # Get FSDPParam list
+            fsdp_param_group = fully_shard.state(root_module)._fsdp_param_group
+            fsdp_param_mapping = _create_fsdp_param_mapping(
+                fsdp_param_group.fsdp_params, root_model
+            )
 
+            # Assert that all the modules in the module list are present in this fsdp_param_group
+            for module in modules_to_update:
+                name = _get_module_name(module, root_model)
+                assert name in fsdp_param_mapping, (
+                    f"Module {module} not found in fsdp_param_mapping"
+                )
         # Yields for necessary weight updates/processing
         yield
     finally:
-        # Update FSDPParam list
-        for module in modules_to_update:
-            name = _get_module_name(module, root_module)
-            old_fsdp_param = fsdp_param_mapping[name]
-
-            # Update mp policy to reflect the new dtype
-            new_mp_policy = MixedPrecisionPolicy(
-                param_dtype=module.weight.dtype,
-                reduce_dtype=None,
-                output_dtype=None,
-                cast_forward_inputs=False,
-            )
+        from torch.distributed.fsdp import fully_shard
 
-            with no_requires_grad():
-                # Create a new QFSDPParam or FSDPParam based on weight type
-                param_class = QFSDPParam if isinstance(module.weight, QTensorWrapper) else FSDPParam
-                new_param = param_class(
-                    module.weight,
-                    old_fsdp_param._module_info,
-                    old_fsdp_param.mesh_info,
-                    old_fsdp_param.post_forward_mesh_info,
-                    old_fsdp_param.device,
-                    None,
-                    new_mp_policy,
-                    None,
+        from modelopt.torch.quantization.utils import _get_enclosing_fsdp_module, _get_module_name
+
+        if isinstance(root_model, FSDPModule):
+            # Update FSDPParam list
+            for module in modules_to_update:
+                name = _get_module_name(module, root_model)
+                old_fsdp_param = fsdp_param_mapping[name]
+
+                # Update mp policy to reflect the new dtype
+                new_mp_policy = MixedPrecisionPolicy(
+                    param_dtype=module.weight.dtype,
+                    reduce_dtype=None,
+                    output_dtype=None,
+                    cast_forward_inputs=False,
                 )
 
-            # Update the FSDPParam mapping to keep track of the new FSDPParam
-            fsdp_param_mapping[name] = new_param
+                with no_requires_grad():
+                    # Create a new QFSDPParam or FSDPParam based on weight type
+                    param_class = (
+                        QFSDPParam if isinstance(module.weight, QTensorWrapper) else FSDPParam
+                    )
+                    new_param = param_class(
+                        module.weight,
+                        old_fsdp_param._module_info,
+                        old_fsdp_param.mesh_info,
+                        old_fsdp_param.post_forward_mesh_info,
+                        old_fsdp_param.device,
+                        None,
+                        new_mp_policy,
+                        None,
+                    )
+                    if not isinstance(new_param, QFSDPParam):
+                        new_param.init_dtype_attrs(new_mp_policy)
+
+                # Update the FSDPParam mapping to keep track of the new FSDPParam
+                fsdp_param_mapping[name] = new_param
 
-            # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
-            old_fsdp_param._post_load_hook_handle.remove()
+                # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
+                old_fsdp_param._post_load_hook_handle.remove()
 
-        # Update FSDPParam list with new compressed weights
-        fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
+            # Update FSDPParam list with new compressed weights
+            fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
 
-        # Reshard FSDP root module
-        # TODO: Check if reshard is needed or not
-        root_module.reshard()
+            # Reshard FSDP root module
+            # TODO: Check if reshard is needed or not
+            root_module.reshard()
 
 
 def pack_real_quantize_weight(module, force_quantize: bool = False):
@@ -422,39 +434,8 @@ def _compress_fsdp_module(fsdp_module):
             if name not in fsdp_param_mapping:
                 continue
 
-            if _compress_and_update_module_weight(submodule):
-                old_fsdp_param = fsdp_param_mapping[name]
-
-                # Update mp policy to reflect the new dtype
-                new_mp_policy = MixedPrecisionPolicy(
-                    param_dtype=submodule.weight.dtype,
-                    reduce_dtype=None,
-                    output_dtype=None,
-                    cast_forward_inputs=False,
-                )
-                with no_requires_grad():
-                    # Create a new QFSDPParam parameter
-                    new_param = QFSDPParam(
-                        submodule.weight,
-                        old_fsdp_param._module_info,
-                        old_fsdp_param.mesh_info,
-                        old_fsdp_param.post_forward_mesh_info,
-                        old_fsdp_param.device,
-                        None,
-                        new_mp_policy,
-                        None,
-                    )
-
-                # Update the FSDPParam mapping to keep track of the new FSDPParam
-                fsdp_param_mapping[name] = new_param
-                # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
-                old_fsdp_param._post_load_hook_handle.remove()
-
-        # Update FSDPParam list with new compressed weights
-        fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
-
-        # Reshard FSDP root module
-        fsdp_module.reshard()
+            with fsdp2_aware_weight_update(fsdp_module, submodule):
+                _compress_and_update_module_weight(submodule)
 
     with SequentialQuantizer.convert_to_single_quantizer(module), torch.no_grad():
         for _, m in module.named_modules():
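
For reference, a minimal usage sketch of the refactored context manager (not part of the commit): after this change, fsdp2_aware_weight_update accepts any root model and only performs the FSDP2 unshard and FSDPParam-rebuild bookkeeping when the root is an FSDPModule, so plain non-sharded models pass straight through. The toy model and the quantize_linear_weight_ helper below are hypothetical stand-ins for a caller's real weight-processing step; only the context manager itself comes from the diff above.

import torch
import torch.nn as nn

# Assumed import location for the sketch; adjust to wherever the patched module lives.
# from modelopt.torch.quantization.qtensor.base_qtensor import fsdp2_aware_weight_update

def quantize_linear_weight_(linear: nn.Linear) -> None:
    # Hypothetical stand-in for real weight compression: replace the weight in place
    # with a lower-precision copy, the kind of dtype change the context manager's
    # MixedPrecisionPolicy / FSDPParam rebuild is meant to absorb.
    linear.weight = nn.Parameter(linear.weight.to(torch.bfloat16), requires_grad=False)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
# `model` may or may not have been wrapped with torch.distributed.fsdp.fully_shard;
# with the isinstance(root_model, FSDPModule) guard, the same call site works in both cases.
for submodule in model.modules():
    if isinstance(submodule, nn.Linear):
        with fsdp2_aware_weight_update(model, submodule):
            quantize_linear_weight_(submodule)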