
Commit ffbee8f

Refactor
Signed-off-by: Suguna Velury <[email protected]>
1 parent 6e24e68 commit ffbee8f

5 files changed: +202 −239 lines


examples/llm_ptq/multinode-ptq.py

Lines changed: 24 additions & 38 deletions
@@ -21,9 +21,7 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
 from modelopt.torch.export import get_model_type
-from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
-from modelopt.torch.export.quant_utils import postprocess_state_dict
-from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
+from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets
@@ -121,11 +119,6 @@ def parse_args():
         action="store_true",
         help="Trust remote code for HuggingFace models",
     )
-    parser.add_argument(
-        "--attn_implementation",
-        type=str,
-        help="Attention implementation to use (passed to HF model loading)",
-    )
     parser.add_argument("--awq_block_size", default=0, type=int)
 
     args = parser.parse_args()
@@ -159,6 +152,8 @@ def load_and_prepare_model(
     )
     model.eval()
     model_type = get_model_type(model)
+    # Need the original architectures for export
+    # FSDP prefix is added to the architectures for FSDP2 wrapped models
     original_architectures = model.config.architectures
 
     # FSDP2 requires an optimizer to be prepared together with the model
@@ -274,6 +269,8 @@ def calibrate(unwrapped_model):
                 for k, v in batch.items()
             }
             # Use outer model (FSDP-wrapped), not the parameter
+            # Important: We should forward pass using the unwrapped model
+            # mtq.quantize will unwrap the model & pass to the forward_loop
             model(**batch)
 
     return calibrate
@@ -293,41 +290,27 @@ def export_model(
         export_path: Directory to export model to
     """
     export_dir = Path(export_path)
-    export_dir.mkdir(parents=True, exist_ok=True)
 
     # Get quantization config
-    _, hf_quant_config = _export_hf_checkpoint(model, dtype=torch.bfloat16)
-
-    # Gather and post-process state dict
-    model_state_dict = accelerator.get_state_dict(model)
-    post_state_dict = postprocess_state_dict(model_state_dict, 1.0, None)
-
-    # Save quantization config
-    if accelerator.is_main_process:
-        with open(export_dir / "hf_quant_config.json", "w") as f:
-            json.dump(hf_quant_config, f, indent=4)
-
-    # Convert config format
-    hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
-
-    # Save model
-    model.save_pretrained(
-        export_dir,
-        state_dict=post_state_dict,
-        save_modelopt_state=False,
-    )
+    export_hf_checkpoint(
+        model,
+        dtype=torch.bfloat16,
+        export_dir=export_dir,
+        save_modelopt_state=False,
+        is_fsdp2=True,
+        accelerator=accelerator,
+    )
 
-    # Update config with quantization info
-    config_path = export_dir / "config.json"
-    with open(config_path) as f:
-        config_data = json.load(f)
+    # Update config with quantization info
+    config_path = export_dir / "config.json"
+    with open(config_path) as f:
+        config_data = json.load(f)
 
-    config_data["quantization_config"] = hf_quant_config
-    # Update architectures with original architecture. FSDP prefix must be removed for FSDP wrapped models.
-    config_data["architectures"] = architectures
+    # Update architectures with original architecture. FSDP prefix must be removed for FSDP wrapped models.
+    config_data["architectures"] = architectures
 
-    with open(config_path, "w") as f:
-        json.dump(config_data, f, indent=4)
+    with open(config_path, "w") as f:
+        json.dump(config_data, f, indent=4)
 
 
 def main(args):
@@ -402,10 +385,13 @@ def main(args):
     print(f"Quantization completed in {elapsed:.2f}s")
     mtq.print_quant_summary(model)
 
+    start_time = time.time()
     export_model(model, accelerator, args.export_path, original_architectures)
+    elapsed = time.time() - start_time
 
     if accelerator.is_main_process:
         # Export the model
+        print(f"Export completed in {elapsed:.2f}s")
         print(f"Model exported to {args.export_path}")
 
     print("Unpatching FSDP2 MP dtypes")

modelopt/torch/export/layer_utils.py

Lines changed: 1 addition & 3 deletions
@@ -345,9 +345,7 @@ def is_moe(module: nn.Module) -> bool:
 
 def is_quantlinear(module: nn.Module) -> bool:
     """Returns whether the module is a quantized linear layer."""
-    return (
-        "QuantLinear" in type(module).__name__ and "lora" not in type(module).__name__.lower()
-    ) or ("Quant" in type(module).__name__ and "Linear" in type(module).__name__)
+    return "QuantLinear" in type(module).__name__ and "lora" not in type(module).__name__.lower()
 
 
 def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor:
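
Note: the new predicate is stricter. Only class names containing the contiguous substring "QuantLinear" (and not a LoRA variant) qualify; names that merely contain "Quant" and "Linear" separately no longer match. A small self-contained illustration with made-up module classes (the class names here are hypothetical, chosen only to show the behavior):

import torch.nn as nn


def is_quantlinear(module: nn.Module) -> bool:
    """Stricter check, as introduced in this commit."""
    return "QuantLinear" in type(module).__name__ and "lora" not in type(module).__name__.lower()


class QuantLinear(nn.Linear): ...      # matches: name contains "QuantLinear"
class LoraQuantLinear(nn.Linear): ...  # rejected: "lora" appears in the lowercased name
class QuantRowLinear(nn.Linear): ...   # rejected now; the removed "Quant" + "Linear" clause used to match it


print(is_quantlinear(QuantLinear(4, 4)))      # True
print(is_quantlinear(LoraQuantLinear(4, 4)))  # False
print(is_quantlinear(QuantRowLinear(4, 4)))   # False (True before this change)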

modelopt/torch/export/unified_export_hf.py

Lines changed: 16 additions & 5 deletions
@@ -26,13 +26,13 @@
 
 import torch
 import torch.nn as nn
+from accelerate import Accelerator
 from safetensors.torch import save_file
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
 from modelopt.torch.quantization.qtensor import NVFP4QTensor
-from modelopt.torch.quantization.qtensor.base_qtensor import fsdp2_aware_weight_update
-from modelopt.torch.quantization.utils import quantizer_attr_names
+from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names
 
 from .convert_hf_config import convert_hf_quant_config_format
 from .layer_utils import (
@@ -344,7 +344,10 @@ def _export_quantized_weight(
 
 
 def _export_hf_checkpoint(
-    model: nn.Module, dtype: torch.dtype | None = None
+    model: nn.Module,
+    dtype: torch.dtype | None = None,
+    is_fsdp2: bool = False,
+    accelerator: Accelerator | None = None,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.
@@ -490,7 +493,11 @@ def _export_hf_checkpoint(
             with fsdp2_aware_weight_update(model, sub_module):
                 _export_quantized_weight(sub_module, dtype, weight_name)
 
-    quantized_state_dict = model.state_dict()
+    if is_fsdp2:
+        assert accelerator is not None, "Accelerator is required for FSDP2 export"
+        quantized_state_dict = accelerator.get_state_dict(model)
+    else:
+        quantized_state_dict = model.state_dict()
 
     quantized_state_dict = postprocess_state_dict(
         quantized_state_dict, kv_cache_max_bound, kv_cache_format
@@ -508,6 +515,8 @@ def export_hf_checkpoint(
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
+    is_fsdp2: bool = False,
+    accelerator: Accelerator | None = None,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.
@@ -529,7 +538,9 @@ def export_hf_checkpoint(
         return
 
     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(
+            model, dtype, is_fsdp2, accelerator
+        )
 
         # Save hf_quant_config.json for backward compatibility
         with open(f"{export_dir}/hf_quant_config.json", "w") as file:
