
Commit 067f02d

refactor
Signed-off-by: Suguna Velury <[email protected]>
1 parent 91a9505 commit 067f02d

File tree

3 files changed: +30 -46 lines changed

examples/llm_ptq/README.md

Lines changed: 1 addition & 1 deletion

@@ -241,7 +241,7 @@ ModelOpt enables quantization of LLMs across multiple GPU nodes using various qu
 
 ### Usage
 
-For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized based on your specific requirements.
+For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user-specific requirements.
 
 On each node run the following command:
 
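The command itself sits just below this hunk in the README and is not touched by the commit; it is the script's `accelerate` invocation, roughly along the lines of `accelerate launch --config_file fsdp2.yaml multinode-ptq.py --qformat <fmt> ...`, with per-node details such as machine rank and the main node's address supplied via the YAML or the CLI. The flags other than `--qformat` are illustrative here, not taken from this diff.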

examples/llm_ptq/multinode-ptq.py

Lines changed: 27 additions & 40 deletions

@@ -21,7 +21,8 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
 from modelopt.torch.export import get_model_type
-from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
+from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
+from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets

@@ -30,18 +31,11 @@
 RAND_SEED = 1234
 
 QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
-    "int8": mtq.INT8_DEFAULT_CFG,
-    "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
     "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
     "fp8": mtq.FP8_DEFAULT_CFG,
     "int4_awq": mtq.INT4_AWQ_CFG,
-    "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
     "nvfp4": mtq.NVFP4_DEFAULT_CFG,
     "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
-    "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
-    "fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
-    "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
-    "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
     "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
 }
 

@@ -52,18 +46,6 @@
     "nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
 }
 
-SUPPORTED_QFORMATS = [
-    "int8_wo",
-    "int4_awq",
-    "fp8",
-    "nvfp4",
-    "nvfp4_awq",
-    "w4a8_awq",
-    "fp8_pb_wo",
-    "w4a8_mxfp4_fp8",
-    "nvfp4_mlp_only",
-]
-
 
 # Enable HuggingFace checkpointing
 mto.enable_huggingface_checkpointing()

@@ -83,7 +65,7 @@ def parse_args():
     parser.add_argument(
         "--qformat",
         default="fp8",
-        choices=SUPPORTED_QFORMATS,
+        choices=QUANT_CFG_CHOICES.keys(),
         help="Quantization format",
     )
     parser.add_argument(
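With `SUPPORTED_QFORMATS` removed, `QUANT_CFG_CHOICES` is the single source of truth for both the argparse choices and the config lookup. A minimal sketch of that lookup pattern, using only formats present in the dict above; the calibration step is elided and nothing here beyond the ModelOpt config names is taken from the script:

```python
# Sketch only: how a --qformat string maps to a ModelOpt config via QUANT_CFG_CHOICES.
import modelopt.torch.quantization as mtq

QUANT_CFG_CHOICES = {
    "fp8": mtq.FP8_DEFAULT_CFG,
    "nvfp4": mtq.NVFP4_DEFAULT_CFG,
}

qformat = "fp8"  # stand-in for args.qformat from argparse
if qformat not in QUANT_CFG_CHOICES:
    raise ValueError(f"Quantization format {qformat} not supported.")

quant_cfg = QUANT_CFG_CHOICES[qformat]
# model = mtq.quantize(model, quant_cfg, forward_loop=calibrate)  # calibration loop omitted
```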
@@ -290,27 +272,32 @@ def export_model(
         export_path: Directory to export model to
     """
     export_dir = Path(export_path)
+    export_dir.mkdir(parents=True, exist_ok=True)
 
-    # Get quantization config
-    export_hf_checkpoint(
-        model,
-        dtype=torch.bfloat16,
-        export_dir=export_dir,
-        save_modelopt_state=False,
-        is_fsdp2=True,
-        accelerator=accelerator,
-    )
+    post_state_dict, hf_quant_config = _export_hf_checkpoint(model, torch.bfloat16)
+
+    if accelerator.is_main_process:
+        # Save hf_quant_config.json for backward compatibility
+        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+            json.dump(hf_quant_config, file, indent=4)
+
+        hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
+
+        # Save model
+        model.save_pretrained(export_dir, state_dict=post_state_dict, save_modelopt_state=False)
+
+        original_config = f"{export_dir}/config.json"
+        config_data = {}
 
-    # Update config with quantization info
-    config_path = export_dir / "config.json"
-    with open(config_path) as f:
-        config_data = json.load(f)
+        with open(original_config) as file:
+            config_data = json.load(file)
 
-    # Update architectures with original architecture. FSDP prefix must be removed for FSDP wrapped models.
-    config_data["architectures"] = architectures
+        config_data["quantization_config"] = hf_quant_config
+        # Update config architectures to use the original architectures, which do not have the FSDP prefix
+        config_data["architectures"] = architectures
 
-    with open(config_path, "w") as f:
-        json.dump(config_data, f, indent=4)
+        with open(original_config, "w") as file:
+            json.dump(config_data, file, indent=4)
 
 
 def main(args):
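The export path above follows the usual `accelerate` pattern: every rank participates in the collective work (gathering the state dict inside `_export_hf_checkpoint`), while only the main process touches the filesystem. A self-contained sketch of that guard; the payload and output directory are placeholders, not from the script:

```python
# Sketch of the main-process-only write pattern used in export_model above.
import json
from pathlib import Path

from accelerate import Accelerator

accelerator = Accelerator()
export_dir = Path("exported_model")  # placeholder output directory

# Collective work runs on every rank; file writes are restricted to the main process.
payload = {"rank": accelerator.process_index}  # stand-in for hf_quant_config
if accelerator.is_main_process:
    export_dir.mkdir(parents=True, exist_ok=True)
    with open(export_dir / "hf_quant_config.json", "w") as f:
        json.dump(payload, f, indent=4)
accelerator.wait_for_everyone()
```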
@@ -320,9 +307,9 @@ def main(args):
         raise OSError("GPU is required for quantization.")
 
     # Validate quantization format
-    if args.qformat not in SUPPORTED_QFORMATS:
+    if args.qformat not in QUANT_CFG_CHOICES:
         raise ValueError(
-            f"Quantization format {args.qformat} not supported. Choose from: {SUPPORTED_QFORMATS}"
+            f"Quantization format {args.qformat} not supported. Choose from: {QUANT_CFG_CHOICES.keys()}"
         )
 
     # Set random seeds

modelopt/torch/export/unified_export_hf.py

Lines changed: 2 additions & 5 deletions

@@ -495,6 +495,7 @@ def _export_hf_checkpoint(
 
     if is_fsdp2:
         assert accelerator is not None, "Accelerator is required for FSDP2 export"
+        # Gather state_dict from all ranks
         quantized_state_dict = accelerator.get_state_dict(model)
     else:
         quantized_state_dict = model.state_dict()
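For context, `accelerator.get_state_dict(model)` is the `accelerate` API that materializes a consolidated, unsharded state dict from a model prepared by the accelerator (here FSDP2-wrapped); all ranks must call it even though only one rank writes the result. A rough sketch with a toy model standing in for the real one:

```python
# Sketch: accelerator.get_state_dict returns a consolidated state dict,
# regardless of how the model is sharded under the hood.
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(nn.Linear(4, 4))  # toy model; FSDP2 wrapping is driven by the accelerate config

full_state_dict = accelerator.get_state_dict(model)  # every rank participates in the gather
if accelerator.is_main_process:
    print(sorted(full_state_dict.keys()))  # ['bias', 'weight']
```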
@@ -515,8 +516,6 @@ def export_hf_checkpoint(
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
-    is_fsdp2: bool = False,
-    accelerator: Accelerator | None = None,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.
 

@@ -538,9 +537,7 @@
         return
 
     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(
-            model, dtype, is_fsdp2, accelerator
-        )
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
 
         # Save hf_quant_config.json for backward compatibility
         with open(f"{export_dir}/hf_quant_config.json", "w") as file:
