Commit 0f12136

updated context manager to handle weight and bias
Signed-off-by: Suguna Velury <[email protected]>
1 parent c1c5ca0 commit 0f12136

File tree: 1 file changed (+47 -41 lines)

modelopt/torch/quantization/utils.py

Lines changed: 47 additions & 41 deletions
@@ -577,7 +577,7 @@ def get_prefixed_param_names(parent_model, target_module):
 
 
 def create_fsdp_param_mapping(fsdp_param_list, model):
-    """Builds a mapping from module name to their corresponding FSDPParam.
+    """Builds a mapping from full parameter name to their corresponding FSDPParam.
 
     Args:
         fsdp_param_list (list): List of FSDPParam.
@@ -586,10 +586,16 @@ def create_fsdp_param_mapping(fsdp_param_list, model):
     Returns:
         dict: Full parameter name → FSDP parameter.
     """
-    return {
-        get_prefixed_param_names(model, param._module_info.module): param
-        for param in fsdp_param_list
-    }
+    mapping = {}
+    for param in fsdp_param_list:
+        # Get the module name
+        module_name = get_prefixed_param_names(model, param._module_info.module)
+        if module_name is not None:
+            # Get the parameter name from _module_info and construct full param name
+            param_name = param._module_info.param_name
+            full_param_name = f"{module_name}.{param_name}"
+            mapping[full_param_name] = param
+    return mapping
 
 
 @contextmanager
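
(Aside, not part of the commit: the new mapping is keyed by dotted parameter names of the form "<module name>.<parameter name>", one entry per parameter instead of one per module, which is what lets a module's weight and bias be tracked separately. Below is a minimal sketch of that keying convention, using an ordinary nn.Sequential in place of the real FSDP-wrapped model and plain tensors in place of FSDPParam objects; the model layout and names are illustrative only.)

import torch.nn as nn

# Toy model standing in for the FSDP-wrapped model; the layout is illustrative only.
model = nn.Sequential(nn.Linear(4, 4, bias=True))

# Build "module_name.param_name" keys, mirroring the keying scheme the updated
# create_fsdp_param_mapping uses (module prefix plus parameter name).
mapping = {}
for module_name, module in model.named_modules():
    for param_name, param in module.named_parameters(recurse=False):
        full_param_name = f"{module_name}.{param_name}" if module_name else param_name
        mapping[full_param_name] = param

# Both the weight and the bias get their own entry: "0.weight" and "0.bias".
print(sorted(mapping))
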
@@ -718,44 +724,44 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
         if isinstance(root_model, FSDPModule):
             # Update FSDPParam list
             for module in modules_to_update:
-                name = _get_module_name(module, root_model)
-                if name not in fsdp_param_mapping:
-                    continue
-
-                old_fsdp_param = fsdp_param_mapping[name]
-
-                # Update mp policy to reflect the new dtype
-                new_mp_policy = MixedPrecisionPolicy(
-                    param_dtype=module.weight.dtype,
-                    reduce_dtype=None,
-                    output_dtype=None,
-                    cast_forward_inputs=False,
-                )
-
-                with no_requires_grad():
-                    # Create a new QFSDPParam or FSDPParam based on weight type
-                    param_class = (
-                        QFSDPParam if isinstance(module.weight, QTensorWrapper) else FSDPParam
-                    )
-
-                    new_param = param_class(
-                        module.weight,
-                        old_fsdp_param._module_info,
-                        old_fsdp_param.mesh_info,
-                        old_fsdp_param.post_forward_mesh_info,
-                        old_fsdp_param.device,
-                        None,
-                        new_mp_policy,
-                        None,
+                for n, p in module.named_parameters():
+                    name = _get_module_name(module, root_model)
+                    name = f"{name}.{n}"
+                    if name not in fsdp_param_mapping:
+                        continue
+
+                    old_fsdp_param = fsdp_param_mapping[name]
+
+                    # Update mp policy to reflect the new dtype
+                    new_mp_policy = MixedPrecisionPolicy(
+                        param_dtype=p.dtype,
+                        reduce_dtype=None,
+                        output_dtype=None,
+                        cast_forward_inputs=False,
                     )
-                if not isinstance(new_param, QFSDPParam):
-                    new_param.init_dtype_attrs(new_mp_policy)
-
-                # Update the FSDPParam mapping to keep track of the new FSDPParam
-                fsdp_param_mapping[name] = new_param
 
-                # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
-                old_fsdp_param._post_load_hook_handle.remove()
+                    with no_requires_grad():
+                        # Create a new QFSDPParam or FSDPParam based on weight type
+                        param_class = QFSDPParam if isinstance(p, QTensorWrapper) else FSDPParam
+
+                        new_param = param_class(
+                            p,
+                            old_fsdp_param._module_info,
+                            old_fsdp_param.mesh_info,
+                            old_fsdp_param.post_forward_mesh_info,
+                            old_fsdp_param.device,
+                            None,
+                            new_mp_policy,
+                            None,
+                        )
+                    if not isinstance(new_param, QFSDPParam):
+                        new_param.init_dtype_attrs(new_mp_policy)
+
+                    # Update the FSDPParam mapping to keep track of the new FSDPParam
+                    fsdp_param_mapping[name] = new_param
+
+                    # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
+                    old_fsdp_param._post_load_hook_handle.remove()
 
             # Update FSDPParam list with new compressed weights
             fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
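
(Aside, not part of the commit: switching from module.weight to module.named_parameters() means the update loop now visits every parameter of each module, so a bias is handled the same way as a weight whenever the mapping contains its full name. Below is a rough, self-contained sketch of that lookup pattern; update_params and the toy mapping are hypothetical stand-ins for the FSDPParam bookkeeping, not modelopt APIs.)

import torch
import torch.nn as nn

def update_params(module: nn.Module, module_name: str, param_mapping: dict) -> None:
    # Visit every parameter (weight, bias, ...) rather than only module.weight,
    # mirroring the loop structure introduced in fsdp2_aware_weight_update.
    for n, p in module.named_parameters():
        name = f"{module_name}.{n}"
        if name not in param_mapping:
            continue
        # The real code builds a new (Q)FSDPParam from `p` and a MixedPrecisionPolicy
        # matching p.dtype; here we simply record the parameter that was found.
        param_mapping[name] = p

# Hypothetical usage with a mapping keyed the same way as create_fsdp_param_mapping.
layer = nn.Linear(4, 4, bias=True).to(torch.bfloat16)
mapping = {"layer.weight": None, "layer.bias": None}
update_params(layer, "layer", mapping)
print({k: v.dtype for k, v in mapping.items()})  # both weight and bias were matched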
