Skip to content

Commit f81c370

Browse files
committed
minor update
Signed-off-by: Suguna Velury <[email protected]>
1 parent 6e07dff commit f81c370

File tree

2 files changed

+3
-9
lines changed

2 files changed

+3
-9
lines changed

examples/llm_ptq/multinode-ptq.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,6 @@
5050
# Enable HuggingFace checkpointing
5151
mto.enable_huggingface_checkpointing()
5252

53-
original_init_mp_dtypes = patch_fsdp_mp_dtypes()
54-
5553

5654
def parse_args():
5755
"""Parse command line arguments."""
@@ -275,7 +273,7 @@ def export_model(
275273
export_dir.mkdir(parents=True, exist_ok=True)
276274

277275
post_state_dict, hf_quant_config = _export_hf_checkpoint(
278-
model, torch.bfloat16, is_fsdp2=True, accelerator=accelerator
276+
model, torch.bfloat16, accelerator=accelerator
279277
)
280278

281279
if accelerator.is_main_process:
@@ -384,9 +382,6 @@ def main(args):
384382
print(f"Model exported to {args.export_path}")
385383

386384
print("Unpatching FSDP2 MP dtypes")
387-
torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes = (
388-
original_init_mp_dtypes
389-
)
390385

391386

392387
if __name__ == "__main__":

modelopt/torch/export/unified_export_hf.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -346,7 +346,6 @@ def _export_quantized_weight(
346346
def _export_hf_checkpoint(
347347
model: nn.Module,
348348
dtype: torch.dtype | None = None,
349-
is_fsdp2: bool = False,
350349
accelerator: Accelerator | None = None,
351350
) -> tuple[dict[str, Any], dict[str, Any]]:
352351
"""Exports the torch model to the packed checkpoint with original HF naming.
@@ -356,6 +355,7 @@ def _export_hf_checkpoint(
356355
Args:
357356
model: the torch model.
358357
dtype: the weights data type to export the unquantized layers or the default model data type if None.
358+
accelerator: the accelerator instance in case of distributed export setup.
359359
360360
Returns:
361361
post_state_dict: Dict containing quantized weights
@@ -493,8 +493,7 @@ def _export_hf_checkpoint(
493493
with fsdp2_aware_weight_update(model, sub_module):
494494
_export_quantized_weight(sub_module, dtype, weight_name)
495495

496-
if is_fsdp2:
497-
assert accelerator is not None, "Accelerator is required for FSDP2 export"
496+
if accelerator is not None:
498497
# Gather state_dict from all ranks
499498
quantized_state_dict = accelerator.get_state_dict(model)
500499
else:

0 commit comments

Comments (0)