
Commit 417f17a

added optimization for export and extra note on performance
Signed-off-by: Suguna Velury <[email protected]>
1 parent d95e2fd commit 417f17a

3 files changed: +19 −7 lines

- examples/llm_ptq/README.md
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/quantization/utils.py

examples/llm_ptq/README.md

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ accelerate launch --config_file fsdp2.yaml \
 
 The exported checkpoint can be deployed using TensorRT-LLM/ vLLM/ SGLang. For more details refer to the [deployment section](#deployment) of this document.
 
-> *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size based on available GPU memory.*
+> *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size based on available GPU memory and choose the right number of GPUs to avoid unnecessary communication.*
 >
 ## Framework Scripts
 
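To make the batch-size half of that note concrete, here is a rough heuristic sketch (the helper name `pick_calib_batch_size` and the per-sample memory estimate are assumptions for illustration, not part of the PTQ scripts): derive the calibration batch size from the free memory reported on the current GPU.

```python
import torch

def pick_calib_batch_size(per_sample_bytes: int, safety: float = 0.8) -> int:
    """Heuristic sketch: derive a calibration batch size from free GPU memory."""
    free_bytes, _total_bytes = torch.cuda.mem_get_info()  # free/total bytes on the current CUDA device
    return max(1, int(free_bytes * safety) // per_sample_bytes)
```

The GPU-count half of the note follows from FSDP2 sharding: each additional rank participates in the gather/reshard traffic during calibration and export, so roughly speaking, use only as many GPUs as the model actually needs to fit.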

modelopt/torch/export/unified_export_hf.py

Lines changed: 15 additions & 3 deletions
@@ -32,6 +32,7 @@
 except ImportError:  # pragma: no cover
     Accelerator = None
 from safetensors.torch import save_file
+from torch.distributed.fsdp import FSDPModule
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
@@ -350,7 +351,7 @@ def _export_quantized_weight(
 def _export_hf_checkpoint(
     model: nn.Module,
     dtype: torch.dtype | None = None,
-    accelerator: Accelerator | None = None,
+    **kwargs,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.
 
@@ -373,6 +374,8 @@
             f"({dtype}), which may lead to numerical errors."
         )
 
+    accelerator = kwargs.get("accelerator")
+
     # Create a model layer pool
     # If `model.model` exists use that, otherwise use `model` itself, e.g., Nemotron-H
     root = getattr(model, "model", model)
@@ -470,12 +473,21 @@
 
     # Track if any layers are quantized to properly set exclude_modules
     has_quantized_layers = False
+    fsdp_module_to_reshard = None
 
     for name, sub_module in layer_pool.items():
+        # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead
+        if isinstance(sub_module, FSDPModule):
+            # Every time we encounter a new FSDPModule, we need to reshard the previous one
+            if fsdp_module_to_reshard is not None:
+                fsdp_module_to_reshard.reshard()
+
+            fsdp_module_to_reshard = sub_module
+
         if get_quantization_format(sub_module) != QUANTIZATION_NONE:
             has_quantized_layers = True
             if is_quantlinear(sub_module):
-                with fsdp2_aware_weight_update(model, sub_module):
+                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                     _export_quantized_weight(sub_module, dtype)
             elif (
                 "Llama4TextExperts" in type(sub_module).__name__
@@ -494,7 +506,7 @@
                 )
                 # Export the quantized weights
                 for weight_name in ["gate_up_proj", "down_proj"]:
-                    with fsdp2_aware_weight_update(model, sub_module):
+                    with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                         _export_quantized_weight(sub_module, dtype, weight_name)
 
     if accelerator is not None:
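The loop above defers resharding so that each FSDP2-wrapped decoder layer is resharded at most once, instead of after every individual weight update. A minimal standalone sketch of the same pattern (the `layer_pool`/`export_fn` names and the trailing reshard are illustrative assumptions, not the exporter's exact code):

```python
from torch.distributed.fsdp import FSDPModule

def export_with_deferred_reshard(layer_pool: dict, export_fn) -> None:
    """Visit each layer, resharding the previously visited FSDP2 module exactly once."""
    pending: FSDPModule | None = None  # last FSDPModule whose reshard is still deferred
    for name, sub_module in layer_pool.items():
        if isinstance(sub_module, FSDPModule):
            if pending is not None:
                pending.reshard()  # free the previous layer's unsharded parameters
            pending = sub_module
        export_fn(name, sub_module)  # weight updates run with reshard=False (see utils.py below)
    if pending is not None:
        pending.reshard()  # reshard whichever layer is still outstanding
```

The effect is that at most one decoder layer sits unsharded at a time, while the per-update reshard (and the re-gather it would force on the next update) is skipped.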

modelopt/torch/quantization/utils.py

Lines changed: 3 additions & 3 deletions
@@ -593,7 +593,7 @@ def enable_fake_quant(module):
 
 
 @contextmanager
-def fsdp2_aware_weight_update(root_model, modules_to_update):
+def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
     """Context manager to update the FSDPParam list if an update is made to a submodule of an FSDPModule."""
     try:
         if isinstance(root_model, FSDPModule):
@@ -672,5 +672,5 @@ def fsdp2_aware_weight_update(root_model, modules_to_update):
                 fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
 
         # Reshard FSDP root module
-        # TODO: Add a check to reshard only if necessary, can help performance during export
-        root_module.reshard()
+        if reshard:
+            root_module.reshard()
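The `reshard` flag hands the reshard decision to the caller: the export loop above passes `reshard=False` and reshards each decoder layer itself. A minimal sketch of that division of labor (the `weight_update` name and the unshard/reshard placement are assumptions for illustration, not the library's implementation):

```python
from contextlib import contextmanager

from torch.distributed.fsdp import FSDPModule

@contextmanager
def weight_update(root_module: FSDPModule, reshard: bool = True):
    """Unshard for an in-place weight update; optionally leave resharding to the caller."""
    root_module.unshard()  # gather the full parameters so they can be rewritten in place
    try:
        yield
    finally:
        if reshard:  # export passes reshard=False and reshards once per decoder layer instead
            root_module.reshard()
```

With `reshard=False`, back-to-back updates to modules inside the same decoder layer reuse the already-gathered parameters instead of paying an unshard/reshard round trip per update.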
