4 changes: 4 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -249,6 +249,8 @@ def main(args):
use_seq_device_map=args.use_seq_device_map,
attn_implementation=args.attn_implementation,
)
# Store original model path for config restoration
model._original_model_path = args.pyt_ckpt_path
Collaborator: We can move this above the if-else instead of having it in both the if and else branches.

Contributor Author: Good catch, moved it below the if-else.

else:
assert args.qformat in QUANT_CFG_CHOICES, (
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
@@ -270,6 +272,8 @@ def main(args):
args.pyt_ckpt_path,
**model_kwargs,
)
# Store original model path for config restoration
model._original_model_path = args.pyt_ckpt_path
calibration_only = True
model_is_already_quantized = is_quantized(model)

48 changes: 48 additions & 0 deletions modelopt/torch/export/model_config_utils.py
@@ -227,6 +227,54 @@ def model_config_from_dict(d: dict) -> ModelConfig:
return _from_dict(config_type, d)


def restore_original_rope_scaling(config_data: dict, original_model_path: str) -> dict:
"""Restore original rope_scaling configuration if it was modified by transformers.

Some VLM models like Qwen2.5-VL have their rope_scaling configuration modified
by the transformers library during loading (e.g., from "mrope" to "default" with
additional fields). This function restores the original configuration.

Args:
config_data: The model configuration dictionary to restore
original_model_path: Path to the original model directory

Returns:
The config_data dictionary with restored rope_scaling (modified in-place)
"""
import json
import warnings
from pathlib import Path
Collaborator: We can move these imports to the top.

Contributor Author: Done.


try:
Collaborator: Can we just make a copy of the config.json and keep it the same as before?

Contributor Author: I prefer the current solution; we can't keep it the same, because we need to add the quantization config to config.json.

original_config_file = Path(original_model_path) / "config.json"
if original_config_file.exists():
with open(original_config_file) as f:
raw_original_config = json.load(f)

# Check if rope_scaling was modified from mrope to default
if (
"rope_scaling" in raw_original_config
and "rope_scaling" in config_data
and raw_original_config["rope_scaling"].get("type") == "mrope"
and config_data["rope_scaling"].get("type") == "default"
and "rope_type" in config_data["rope_scaling"]
):
print(f"Restoring original rope_scaling configuration from {original_model_path}")
config_data["rope_scaling"] = raw_original_config["rope_scaling"]

# Also restore rope_scaling in text_config if it exists
if (
"text_config" in config_data
and "rope_scaling" in config_data["text_config"]
and config_data["text_config"]["rope_scaling"].get("type") == "default"
):
config_data["text_config"]["rope_scaling"] = raw_original_config["rope_scaling"]
except Exception as e:
warnings.warn(f"Could not restore original rope_scaling configuration: {e}")

return config_data
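
For context on the review exchange above, a minimal usage sketch of restore_original_rope_scaling as it would be used at export time. The rope_scaling values and the checkpoint path are illustrative of the mrope-to-default rewrite described in the docstring, not copied from a real model, and the quantization_config contents are a placeholder; the last step also shows why a verbatim copy of the original config.json is not enough.

import json

from modelopt.torch.export.model_config_utils import restore_original_rope_scaling

# Illustrative only: a config dict as transformers may have rewritten it
# ("mrope" flattened to "default", with an extra rope_type field added).
exported_config = {
    "rope_scaling": {"type": "default", "rope_type": "default", "mrope_section": [16, 24, 24]},
}

# Hypothetical path to the unmodified checkpoint whose config.json still says "mrope".
original_model_path = "/models/Qwen2.5-VL-7B-Instruct"
exported_config = restore_original_rope_scaling(exported_config, original_model_path)

# A plain copy of the original config.json would not suffice: the exported config
# must also carry the quantization metadata, as export_hf_checkpoint does below.
exported_config["quantization_config"] = {"quant_method": "modelopt"}  # placeholder contents

print(json.dumps(exported_config, indent=2))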


def pad_weights(weights, tp_size):
"""Returns the padded weights to tp_size."""
assert len(weights.shape) > 1
6 changes: 6 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -54,6 +54,7 @@
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
)
from .model_config_utils import restore_original_rope_scaling
from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
from .quant_utils import (
fuse_prequant_layernorm,
@@ -541,6 +542,11 @@ def export_hf_checkpoint(
with open(original_config) as file:
config_data = json.load(file)

# Preserve original rope_scaling configuration if it was modified by transformers
original_model_path = getattr(model, "_original_model_path", None)
if original_model_path is not None:
config_data = restore_original_rope_scaling(config_data, original_model_path)

config_data["quantization_config"] = hf_quant_config

with open(original_config, "w") as file: