@@ -577,7 +577,7 @@ def get_prefixed_param_names(parent_model, target_module):
577577
578578
579579def create_fsdp_param_mapping (fsdp_param_list , model ):
580- """Builds a mapping from module name to their corresponding FSDPParam.
580+ """Builds a mapping from each full parameter name to its corresponding FSDPParam.
581581
582582 Args:
583583 fsdp_param_list (list): List of FSDPParam.
@@ -586,10 +586,16 @@ def create_fsdp_param_mapping(fsdp_param_list, model):
586586 Returns:
587587 dict: Full parameter name → FSDP parameter.
588588 """
589- return {
590- get_prefixed_param_names (model , param ._module_info .module ): param
591- for param in fsdp_param_list
592- }
589+ mapping = {}
590+ for param in fsdp_param_list :
591+ # Get the module name
592+ module_name = get_prefixed_param_names (model , param ._module_info .module )
593+ if module_name is not None :
594+ # Get the parameter name from _module_info and construct full param name
595+ param_name = param ._module_info .param_name
596+ full_param_name = f"{ module_name } .{ param_name } "
597+ mapping [full_param_name ] = param
598+ return mapping
593599
594600
595601@contextmanager
@@ -718,44 +724,44 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
718724 if isinstance (root_model , FSDPModule ):
719725 # Update FSDPParam list
720726 for module in modules_to_update :
721- name = _get_module_name (module , root_model )
722- if name not in fsdp_param_mapping :
723- continue
724-
725- old_fsdp_param = fsdp_param_mapping [name ]
726-
727- # Update mp policy to reflect the new dtype
728- new_mp_policy = MixedPrecisionPolicy (
729- param_dtype = module .weight .dtype ,
730- reduce_dtype = None ,
731- output_dtype = None ,
732- cast_forward_inputs = False ,
733- )
734-
735- with no_requires_grad ():
736- # Create a new QFSDPParam or FSDPParam based on weight type
737- param_class = (
738- QFSDPParam if isinstance (module .weight , QTensorWrapper ) else FSDPParam
739- )
740-
741- new_param = param_class (
742- module .weight ,
743- old_fsdp_param ._module_info ,
744- old_fsdp_param .mesh_info ,
745- old_fsdp_param .post_forward_mesh_info ,
746- old_fsdp_param .device ,
747- None ,
748- new_mp_policy ,
749- None ,
727+ for n , p in module .named_parameters ():
728+ name = _get_module_name (module , root_model )
729+ name = f"{ name } .{ n } "
730+ if name not in fsdp_param_mapping :
731+ continue
732+
733+ old_fsdp_param = fsdp_param_mapping [name ]
734+
735+ # Update mp policy to reflect the new dtype
736+ new_mp_policy = MixedPrecisionPolicy (
737+ param_dtype = p .dtype ,
738+ reduce_dtype = None ,
739+ output_dtype = None ,
740+ cast_forward_inputs = False ,
750741 )
751- if not isinstance (new_param , QFSDPParam ):
752- new_param .init_dtype_attrs (new_mp_policy )
753-
754- # Update the FSDPParam mapping to keep track of the new FSDPParam
755- fsdp_param_mapping [name ] = new_param
756742
757- # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
758- old_fsdp_param ._post_load_hook_handle .remove ()
743+ with no_requires_grad ():
744+ # Create a new QFSDPParam or FSDPParam based on weight type
745+ param_class = QFSDPParam if isinstance (p , QTensorWrapper ) else FSDPParam
746+
747+ new_param = param_class (
748+ p ,
749+ old_fsdp_param ._module_info ,
750+ old_fsdp_param .mesh_info ,
751+ old_fsdp_param .post_forward_mesh_info ,
752+ old_fsdp_param .device ,
753+ None ,
754+ new_mp_policy ,
755+ None ,
756+ )
757+ if not isinstance (new_param , QFSDPParam ):
758+ new_param .init_dtype_attrs (new_mp_policy )
759+
760+ # Update the FSDPParam mapping to keep track of the new FSDPParam
761+ fsdp_param_mapping [name ] = new_param
762+
763+ # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
764+ old_fsdp_param ._post_load_hook_handle .remove ()
759765
760766 # Update FSDPParam list with new compressed weights
761767 fsdp_param_group .fsdp_params = list (fsdp_param_mapping .values ())
0 commit comments